[feat][assistant][I3T96H] add new dataset loading operator TedliumDataset
This commit is contained in:
parent
37cb0b7561
commit
34bffbf768
|
@ -112,6 +112,7 @@
|
|||
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/sbu_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/speech_commands_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/tedlium_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/usps_node.h"
|
||||
|
@ -1448,6 +1449,34 @@ QMnistDataset::QMnistDataset(const std::vector<char> &dataset_dir, const std::ve
|
|||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
// Creates the TedliumNode IR node from char-vector arguments and a shared_ptr sampler.
// A null sampler is forwarded to the node as a null sampler object.
TedliumDataset::TedliumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &release,
                               const std::vector<char> &usage, const std::vector<char> &extensions,
                               const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
  auto parsed_sampler = sampler ? sampler->Parse() : nullptr;
  // Build the node in place and store it as the generic DatasetNode handle.
  ir_node_ = std::static_pointer_cast<DatasetNode>(
    std::make_shared<TedliumNode>(CharToString(dataset_dir), CharToString(release), CharToString(usage),
                                  CharToString(extensions), parsed_sampler, cache));
}
|
||||
|
||||
// Creates the TedliumNode IR node from char-vector arguments and a raw sampler pointer.
// A null pointer is forwarded to the node as a null sampler object.
TedliumDataset::TedliumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &release,
                               const std::vector<char> &usage, const std::vector<char> &extensions,
                               const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
  auto parsed_sampler = sampler ? sampler->Parse() : nullptr;
  // Build the node in place and store it as the generic DatasetNode handle.
  ir_node_ = std::static_pointer_cast<DatasetNode>(
    std::make_shared<TedliumNode>(CharToString(dataset_dir), CharToString(release), CharToString(usage),
                                  CharToString(extensions), parsed_sampler, cache));
}
|
||||
|
||||
// Creates the TedliumNode IR node from char-vector arguments and a sampler passed by reference wrapper.
// The wrapped sampler always exists, so it is parsed unconditionally.
TedliumDataset::TedliumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &release,
                               const std::vector<char> &usage, const std::vector<char> &extensions,
                               const std::reference_wrapper<Sampler> &sampler,
                               const std::shared_ptr<DatasetCache> &cache) {
  auto parsed_sampler = sampler.get().Parse();
  // Build the node in place and store it as the generic DatasetNode handle.
  ir_node_ = std::static_pointer_cast<DatasetNode>(
    std::make_shared<TedliumNode>(CharToString(dataset_dir), CharToString(release), CharToString(usage),
                                  CharToString(extensions), parsed_sampler, cache));
}
|
||||
|
||||
TextFileDataset::TextFileDataset(const std::vector<std::vector<char>> &dataset_files, int64_t num_samples,
|
||||
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
|
||||
const std::shared_ptr<DatasetCache> &cache) {
|
||||
|
|
|
@ -44,6 +44,7 @@
|
|||
#include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/speech_commands_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/tedlium_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/yes_no_node.h"
|
||||
|
||||
|
@ -400,6 +401,18 @@ PYBIND_REGISTER(SpeechCommandsNode, 2, ([](const py::module *m) {
|
|||
}));
|
||||
}));
|
||||
|
||||
// Python binding for TedliumNode. The bound constructor creates the node without a cache and
// runs ValidateParams() on it before handing it back to Python.
PYBIND_REGISTER(TedliumNode, 2, ([](const py::module *m) {
                  (void)py::class_<TedliumNode, DatasetNode, std::shared_ptr<TedliumNode>>(*m, "TedliumNode",
                                                                                           "to create a TedliumNode")
                    .def(py::init([](std::string dataset_dir, std::string release, std::string usage,
                                     std::string extensions, py::handle sampler) {
                      // The Python sampler handle is converted with toSamplerObj; the cache argument is null.
                      std::shared_ptr<TedliumNode> node = std::make_shared<TedliumNode>(
                        dataset_dir, release, usage, extensions, toSamplerObj(sampler), nullptr);
                      THROW_IF_ERROR(node->ValidateParams());
                      return node;
                    }));
                }));
|
||||
|
||||
PYBIND_REGISTER(TextFileNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<TextFileNode, DatasetNode, std::shared_ptr<TextFileNode>>(*m, "TextFileNode",
|
||||
"to create a TextFileNode")
|
||||
|
|
|
@ -29,6 +29,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
|
|||
random_data_op.cc
|
||||
sbu_op.cc
|
||||
speech_commands_op.cc
|
||||
tedlium_op.cc
|
||||
text_file_op.cc
|
||||
usps_op.cc
|
||||
yes_no_op.cc
|
||||
|
|
|
@ -0,0 +1,309 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/engine/datasetops/source/tedlium_op.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
#include <sstream>
|
||||
#include <utility>
|
||||
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/core/tensor_shape.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "utils/file_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// Constructor.
// Stores the dataset location and selection arguments, takes ownership of the schema and hands the
// sampler, worker count and queue size to MappableLeafOp. No files are touched here; the file list is
// built later by PrepareData().
// The member initializers are listed in the order the members are declared in the class
// (release_ is declared before dataset_dir_), which is the order in which they are actually initialized.
TedliumOp::TedliumOp(const std::string &dataset_dir, const std::string &release, const std::string &usage,
                     const std::string &extensions, int32_t num_parallel_workers,
                     std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler, int32_t queue_size)
    : MappableLeafOp(num_parallel_workers, queue_size, std::move(sampler)),
      release_(release),
      dataset_dir_(dataset_dir),
      usage_(usage),
      extensions_(extensions),
      data_schema_(std::move(data_schema)),
      audio_files_(),
      usage_list_() {}
|
||||
|
||||
// Debug printer.
// The ParallelOp part is printed first in both modes; the detailed mode then appends the row count
// and the dataset directory, while the summary mode only ends the line.
void TedliumOp::Print(std::ostream &out, bool show_all) const {
  // Common information provided by the super class (1-liner or detailed, depending on show_all).
  ParallelOp::Print(out, show_all);
  if (show_all) {
    // Op specific details.
    out << "\nNumber of rows: " << num_rows_ << "\nTedliumOp directory: " << dataset_dir_;
    return;
  }
  // No op specific 1-liner information, just terminate the line.
  out << "\n";
}
|
||||
|
||||
// Collects every utterance row of the requested release/usage into audio_files_.
// release1/release2: one "<root>/<usage>/stm" folder per selected usage ("all" expands to train, test, dev).
// release3: only usage "all" is handled, reading "<root>/data/stm".
// The collected rows are sorted and counted; an empty result is reported as an error.
Status TedliumOp::PrepareData() {
  auto resolved_dir = FileUtils::GetRealPath(dataset_dir_.c_str());
  if (!resolved_dir.has_value()) {
    RETURN_STATUS_UNEXPECTED("Invalid file, get real path failed, path=" + dataset_dir_);
  }
  Path root_folder(resolved_dir.value());

  const bool is_release1_or_2 = (release_ == "release1" || release_ == "release2");
  if (is_release1_or_2) {
    if (usage_ == "train" || usage_ == "test" || usage_ == "dev") {
      usage_list_.push_back(usage_);
    } else if (usage_ == "all") {
      usage_list_ = {"train", "test", "dev"};
    }
    for (const std::string &split : usage_list_) {
      Path stm_folder = root_folder / split / "stm";
      RETURN_IF_NOT_OK(ReadStmFolderRows(stm_folder, split));
    }
  } else if (release_ == "release3" && usage_ == "all") {
    Path stm_folder = root_folder / "data" / "stm";
    RETURN_IF_NOT_OK(ReadStmFolderRows(stm_folder, "data"));
  }

  std::sort(audio_files_.begin(), audio_files_.end());
  num_rows_ = audio_files_.size();
  if (num_rows_ == 0) {
    RETURN_STATUS_UNEXPECTED(
      "Invalid data, no valid data matching the dataset API TedliumDataset. Please check file path or dataset API.");
  }
  return Status::OK();
}
|
||||
|
||||
// Scans one stm folder and appends one entry per text line of every ".stm" file to audio_files_.
// Each entry is {line index within the file (as text), first space-delimited token of the line, release_usage}.
Status TedliumOp::ReadStmFolderRows(const Path &stm_folder, const std::string &release_usage) {
  Path dir(stm_folder);
  std::shared_ptr<Path::DirIterator> dir_iter = Path::DirIterator::OpenDirectory(&dir);
  if (!dir.Exists() || dir_iter == nullptr) {
    RETURN_STATUS_UNEXPECTED("Invalid file, failed to open folder: " + dir.ToString());
  }
  MS_LOG(DEBUG) << "Tedlium " + release_ + " stm folder Path found: " << dir << ".";

  while (dir_iter->HasNext()) {
    Path entry = dir_iter->Next();
    if (!(entry.Extension() == ".stm")) {
      // Only transcript files are of interest.
      continue;
    }
    std::ifstream stm_stream(entry.ToString());
    if (!stm_stream.is_open()) {
      RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + entry.ToString());
    }
    std::string row_text;
    for (int32_t line_index = 0; getline(stm_stream, row_text); ++line_index) {
      // The line index is stored as text because every field of an entry is a string.
      std::stringstream index_text;
      index_text << line_index;
      std::string first_token = row_text.substr(0, row_text.find(" "));
      audio_files_.push_back({index_text.str(), first_token, release_usage});
    }
    stm_stream.close();
  }
  return Status::OK();
}
|
||||
|
||||
// Reads line number `row_line` (0-based) of an stm file and splits it into its fields.
// A line consists of seven space-separated pieces:
//   talk_id, _, speaker_id, start_time, end_time, identifier, transcript
// where "_" is not needed and the transcript is the remainder of the line (it may contain spaces).
// \param[in] file_stm_path Path of the stm file.
// \param[in] row_line Index of the line to read.
// \param[out] talk_id, speaker_id, start_time, end_time, identifier, transcript Fields of that line.
// \return Status Error if the file cannot be opened, the line does not exist or the line has fewer than
//     seven fields.
Status TedliumOp::ReadStm(const Path &file_stm_path, int32_t row_line, std::string *talk_id, std::string *speaker_id,
                          std::string *start_time, std::string *end_time, std::string *identifier,
                          std::string *transcript) {
  RETURN_UNEXPECTED_IF_NULL(talk_id);
  RETURN_UNEXPECTED_IF_NULL(speaker_id);
  RETURN_UNEXPECTED_IF_NULL(start_time);
  RETURN_UNEXPECTED_IF_NULL(end_time);
  RETURN_UNEXPECTED_IF_NULL(identifier);
  RETURN_UNEXPECTED_IF_NULL(transcript);

  std::ifstream handle(file_stm_path.ToString().c_str());
  if (!handle.is_open()) {
    RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + file_stm_path.ToString());
  }

  // Skip forward until the requested line is held in `line`.
  std::string line;
  int32_t lines_read = 0;
  while (lines_read <= row_line && getline(handle, line)) {
    ++lines_read;
  }
  handle.close();
  // Without this check a too short file would silently yield its last line (or an empty one).
  if (row_line < 0 || lines_read != row_line + 1) {
    RETURN_STATUS_UNEXPECTED("Invalid data, failed to read line " + std::to_string(row_line) +
                             " from stm file: " + file_stm_path.ToString());
  }

  const int32_t data_stm_number = 7;
  std::vector<std::string> fields;
  // Cut off the first six fields one by one; whatever is left afterwards is the transcript.
  for (int32_t i = 0; i < data_stm_number - 1; ++i) {
    const auto space_pos = line.find(' ');
    if (space_pos == std::string::npos) {
      // Fewer than seven fields. Continuing here would copy the same text into several fields.
      RETURN_STATUS_UNEXPECTED("Invalid data, stm data was broken.");
    }
    fields.push_back(line.substr(0, space_pos));
    line.erase(0, space_pos + 1);  // + 1 also removes the separating space.
  }
  fields.push_back(line);

  const int32_t talk_id_num = 0, speaker_id_num = 2, start_time_num = 3, end_time_num = 4, identifier_num = 5,
                transcript_num = 6;
  *talk_id = fields[talk_id_num];
  // fields[1] is "_", which is the data we don't need.
  *speaker_id = fields[speaker_id_num];
  *start_time = fields[start_time_num];
  *end_time = fields[end_time_num];
  *identifier = fields[identifier_num];
  *transcript = fields[transcript_num];

  return Status::OK();
}
|
||||
|
||||
// Reads the samples between start_time and end_time (in seconds) from an sph audio file.
// The first 1024 bytes of the file are treated as a text header from which the sample rate is taken.
// Samples are decoded as 16-bit big-endian integers and normalized to floats.
// \param[in] file_sph_path Path of the sph file.
// \param[in] start_time Start of the segment in seconds.
// \param[in] end_time End of the segment in seconds.
// \param[out] sample_rate Sample rate found in the header.
// \param[out] result Decoded samples are appended to this vector.
// \return Status Error if the file cannot be opened or read, the header has no usable sample_rate
//     or the time range is invalid.
Status TedliumOp::ReadSph(const Path &file_sph_path, double start_time, double end_time, int32_t *sample_rate,
                          std::vector<float> *result) {
  RETURN_UNEXPECTED_IF_NULL(sample_rate);
  RETURN_UNEXPECTED_IF_NULL(result);
  std::ifstream handle(file_sph_path.ToString().c_str(), std::ios::in | std::ios::binary);
  if (!handle.is_open()) {
    RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + file_sph_path.ToString());
  }

  char head[1024];
  handle.read(head, sizeof(head));
  CHECK_FAIL_RETURN_UNEXPECTED(!handle.fail(),
                               "Invalid data, failed to read head part from sph file: " + file_sph_path.ToString() +
                                 ", re-download dataset(make sure the data is true).");

  // The buffer is filled by a raw read and need not contain a '\0', so the header text is bounded
  // explicitly instead of relying on strlen().
  const char *head_end = std::find(head, head + sizeof(head), '\0');
  const std::string header_text(head, head_end);
  std::istringstream header_stream(header_text);
  std::vector<std::string> tokens;
  for (std::string token; header_stream >> token;) {
    tokens.push_back(token);
  }

  // The value is located two tokens after the key, e.g. "sample_rate -i 16000".
  const size_t kValueOffset = 2;
  bool rate_found = false;
  for (size_t i = 0; i + kValueOffset < tokens.size(); ++i) {
    if (tokens[i] == "sample_rate") {
      *sample_rate = atoi(tokens[i + kValueOffset].c_str());
      rate_found = true;
    }
  }
  CHECK_FAIL_RETURN_UNEXPECTED(rate_found && *sample_rate > 0,
                               "Invalid data, failed to get a valid sample_rate from head part of sph file: " +
                                 file_sph_path.ToString());

  // Convert the time range into sample indices.
  const int32_t kBytesPerSample = 2;
  const int64_t start = static_cast<int64_t>(start_time * (*sample_rate));
  const int64_t end = static_cast<int64_t>(end_time * (*sample_rate));
  CHECK_FAIL_RETURN_UNEXPECTED(start >= 0 && end >= start,
                               "Invalid data, invalid time range for sph file: " + file_sph_path.ToString());
  const size_t size = static_cast<size_t>(end - start);

  // `start` is a sample index, so it is scaled to bytes, and the samples are taken to begin right after
  // the header block read above.
  // NOTE(review): assumes a single channel and a header of exactly sizeof(head) bytes - confirm against the data.
  const std::streamoff data_offset =
    static_cast<std::streamoff>(sizeof(head)) + static_cast<std::streamoff>(start) * kBytesPerSample;
  handle.seekg(data_offset, std::ios::beg);

  // Read straight into the pre-sized buffer. Appending to it instead would leave the zero-filled
  // prefix to be decoded and the real data unused.
  std::vector<char> raw(size * kBytesPerSample);
  if (!raw.empty()) {
    handle.read(raw.data(), static_cast<std::streamsize>(raw.size()));
  }
  CHECK_FAIL_RETURN_UNEXPECTED(!handle.fail(),
                               "Invalid data, failed to read data part from sph file: " + file_sph_path.ToString() +
                                 ", re-download dataset(make sure the data is true).");

  const float kMaxVal = 32767.0;
  result->reserve(result->size() + size);
  for (size_t i = 0; i < size; ++i) {
    const char bh = raw[kBytesPerSample * i];
    const char bl = raw[kBytesPerSample * i + 1];
    // SPH audio files are big-endian, so the two bytes are combined into an int16_t from the
    // high 8 bits and the low 8 bits.
    const int16_t s = static_cast<int16_t>(((bh & 0x00FF) << 8) | (bl & 0x00FF));
    // Data normalization: convert the data from the interval [-32768,32767] to the interval [-1,1].
    result->push_back(s / kMaxVal);
  }
  handle.close();

  return Status::OK();
}
|
||||
|
||||
// Loads one utterance as a tensor row.
// The entry audio_files_[row_id] holds {line index, file name, sub folder}; the matching line of
// "<dataset_dir>/<sub folder>/stm/<file name>.stm" provides the text fields and the time range, and
// "<dataset_dir>/<sub folder>/sph/<file name><extensions>" provides the audio.
// \param[in] row_id Id of the row to load.
// \param[out] row Receives {waveform, sample_rate, transcript, talk_id, speaker_id, identifier}.
// \return Status The status code returned.
Status TedliumOp::LoadTensorRow(row_id_type row_id, TensorRow *row) {
  RETURN_UNEXPECTED_IF_NULL(row);
  CHECK_FAIL_RETURN_UNEXPECTED(row_id >= 0 && static_cast<size_t>(row_id) < audio_files_.size(),
                               "Invalid data, row id " + std::to_string(row_id) + " is out of range.");
  const std::vector<std::string> &entry = audio_files_[row_id];
  const int32_t row_line = atoi(entry[0].c_str());
  const std::string &file_name = entry[1];
  const std::string &sub_folder = entry[2];

  Path dir_path(dataset_dir_);
  Path file_stm_path = dir_path / sub_folder / "stm" / (file_name + ".stm");
  Path file_sph_path = dir_path / sub_folder / "sph" / (file_name + extensions_);

  std::string talk_id, speaker_id, start_time, end_time, identifier, transcript;
  std::vector<float> result;
  // Initialized so that it never holds an indeterminate value, whatever ReadSph finds in the header.
  int32_t sample_rate = 0;
  RETURN_IF_NOT_OK(
    ReadStm(file_stm_path, row_line, &talk_id, &speaker_id, &start_time, &end_time, &identifier, &transcript));
  RETURN_IF_NOT_OK(ReadSph(file_sph_path, atof(start_time.c_str()), atof(end_time.c_str()), &sample_rate, &result));

  std::shared_ptr<Tensor> sample_rate_tensor, talk_id_tensor, speaker_id_tensor, identifier_tensor, transcript_tensor;
  RETURN_IF_NOT_OK(Tensor::CreateScalar(sample_rate, &sample_rate_tensor));
  RETURN_IF_NOT_OK(Tensor::CreateScalar(talk_id, &talk_id_tensor));
  RETURN_IF_NOT_OK(Tensor::CreateScalar(speaker_id, &speaker_id_tensor));
  RETURN_IF_NOT_OK(Tensor::CreateScalar(identifier, &identifier_tensor));
  RETURN_IF_NOT_OK(Tensor::CreateScalar(transcript, &transcript_tensor));

  std::shared_ptr<Tensor> audio_tensor;
  RETURN_IF_NOT_OK(Tensor::CreateFromVector(result, &audio_tensor));
  // Add a leading dimension to the 1-D sample vector.
  RETURN_IF_NOT_OK(audio_tensor->ExpandDim(0));
  (*row) = TensorRow(row_id, {audio_tensor, sample_rate_tensor, transcript_tensor, talk_id_tensor, speaker_id_tensor,
                              identifier_tensor});
  // One source path per column: the first two columns come from the sph file, the others from the stm file.
  row->setPath({file_sph_path.ToString(), file_sph_path.ToString(), file_stm_path.ToString(), file_stm_path.ToString(),
                file_stm_path.ToString(), file_stm_path.ToString()});

  return Status::OK();
}
|
||||
|
||||
// Counts the rows of a TEDLIUM dataset without running a pipeline.
// A temporary TedliumOp is created with a sequential sampler and the default worker/queue settings of
// the config manager; its PrepareData() collects the rows, whose number is written to *count.
Status TedliumOp::CountTotalRows(const std::string &dataset_dir, const std::string &release, const std::string &usage,
                                 const std::string &extensions, int64_t *count) {
  RETURN_UNEXPECTED_IF_NULL(count);
  *count = 0;

  // Sampler for the temporary op.
  const int64_t num_samples = 0;
  const int64_t start_index = 0;
  auto seq_sampler = std::make_shared<SequentialSamplerRT>(start_index, num_samples);

  // Schema for the temporary op: a rank 1 waveform column followed by five scalar columns.
  auto schema = std::make_unique<DataSchema>();
  RETURN_IF_NOT_OK(
    schema->AddColumn(ColDescriptor("waveform", DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1)));
  TensorShape scalar_shape = TensorShape::CreateScalar();
  RETURN_IF_NOT_OK(schema->AddColumn(
    ColDescriptor("sample_rate", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar_shape)));
  for (const char *string_column : {"transcript", "talk_id", "speaker_id", "identifier"}) {
    RETURN_IF_NOT_OK(schema->AddColumn(
      ColDescriptor(string_column, DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar_shape)));
  }

  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
  const int32_t worker_count = cfg->num_parallel_workers();
  const int32_t connector_size = cfg->op_connector_size();
  std::shared_ptr<TedliumOp> counting_op = std::make_shared<TedliumOp>(
    dataset_dir, release, usage, extensions, worker_count, std::move(schema), std::move(seq_sampler), connector_size);
  RETURN_IF_NOT_OK(counting_op->PrepareData());
  *count = static_cast<int64_t>(counting_op->audio_files_.size());
  return Status::OK();
}
|
||||
|
||||
// Fills column_name_id_map_ with "column name -> column index" taken from the schema.
// An already populated map is left untouched and only a warning is logged.
Status TedliumOp::ComputeColMap() {
  if (!column_name_id_map_.empty()) {
    MS_LOG(WARNING) << "Column name map is already set!";
    return Status::OK();
  }
  for (int32_t col = 0; col < data_schema_->NumColumns(); ++col) {
    column_name_id_map_[data_schema_->Column(col).Name()] = col;
  }
  return Status::OK();
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,126 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_TEDLIUM_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_TEDLIUM_OP_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/engine/datasetops/parallel_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
|
||||
#include "minddata/dataset/engine/ir/cache/dataset_cache.h"
|
||||
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "minddata/dataset/util/path.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class TedliumOp : public MappableLeafOp {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] dataset_dir Directory of tedlium dataset.
|
||||
/// \param[in] release Release of tedlium dataset, can be 'release1', 'release2' or 'release3'.
|
||||
/// \param[in] usage Usage of this dataset, if release is release3, can be '', else 'train', 'dev', 'test' or 'all'.
|
||||
/// \param[in] extensions Extensions of the sph file, only '.sph' is valid.
|
||||
/// \param[in] num_parallel_workers Number of workers in parallel.
|
||||
/// \param[in] data_schema Schema of dataset.
|
||||
/// \param[in] sampler Sampler tells TedliumOp what to read.
|
||||
/// \param[in] queue_size Connector queue size.
|
||||
TedliumOp(const std::string &dataset_dir, const std::string &release, const std::string &usage,
|
||||
const std::string &extensions, int32_t num_parallel_workers, std::unique_ptr<DataSchema> data_schema,
|
||||
std::shared_ptr<SamplerRT> sampler, int32_t queue_size);
|
||||
|
||||
/// \brief Destructor.
|
||||
~TedliumOp() = default;
|
||||
|
||||
/// \brief A print method typically used for debugging.
|
||||
/// \param[in] out Out stream.
|
||||
/// \param[in] show_all Whether to show all information.
|
||||
void Print(std::ostream &out, bool show_all) const override;
|
||||
|
||||
/// \brief Op name getter.
|
||||
std::string Name() const override { return "TedliumOp"; }
|
||||
|
||||
/// \brief Initialize TedliumOp related var, calls the function to walk all files.
|
||||
/// \return Status The status code returned.
|
||||
Status PrepareData() override;
|
||||
|
||||
/// \brief Function to count the number of samples in the TEDLIUM dataset.
|
||||
/// \param[in] dataset_dir Directory of tedlium dataset.
|
||||
/// \param[in] release Release of tedlium dataset.
|
||||
/// \param[in] usage Usage of this dataset, if release is release3, can be '', else 'train', 'dev', 'test' or 'all'.
|
||||
/// \param[in] extensions Extensions of the sph file, only '.sph' is valid.
|
||||
/// \param[in] count Output arg that will hold the actual dataset size.
|
||||
/// \return Status The status code returned.
|
||||
static Status CountTotalRows(const std::string &dataset_dir, const std::string &release, const std::string &usage,
|
||||
const std::string &extensions, int64_t *count);
|
||||
|
||||
private:
|
||||
/// \brief Read stm file.
|
||||
/// \param[in] file_stm_path The path of stm file.
|
||||
/// \param[in] row_line Which line of the file we need to read.
|
||||
/// \param[out] talk_id Talk identifier of the row_line in the file.
|
||||
/// \param[out] speaker_id Speaker identifier of the row_line in the file.
|
||||
/// \param[out] start_time Start time of the row_line in the file.
|
||||
/// \param[out] end_time End time of the row_line in the file.
|
||||
/// \param[out] identifier Identifier of the row_line in the file.
|
||||
/// \param[out] transcript Transcript of the row_line in the file.
|
||||
/// \return Status The status code returned.
|
||||
Status ReadStm(const Path &file_stm_path, int32_t row_line, std::string *talk_id, std::string *speaker_id,
|
||||
std::string *start_time, std::string *end_time, std::string *identifier, std::string *transcript);
|
||||
|
||||
/// \brief Read sph file.
|
||||
/// \param[in] file_sph_path The path of sph file.
|
||||
/// \param[in] start_time The start_time of row we need to use.
|
||||
/// \param[in] end_time The end_time of row we need to use.
|
||||
/// \param[out] sample_rate Sample rate of the row.
|
||||
/// \param[out] result Waveform result vector of the row.
|
||||
/// \return Status The status code returned.
|
||||
Status ReadSph(const Path &file_sph_path, double start_time, double end_time, int32_t *sample_rate,
|
||||
std::vector<float> *result);
|
||||
|
||||
/// \brief Read stm files according current release`s usage.
|
||||
/// \param[in] stm_folder The folder of stm files.
|
||||
/// \param[in] release_usage For release1 or release2, use usage_, for release3, "data".
|
||||
/// \return Status The status code returned.
|
||||
Status ReadStmFolderRows(const Path &stm_folder, const std::string &release_usage);
|
||||
|
||||
/// \brief Load a tensor row according to a pair.
|
||||
/// \param[in] row_id Id of row need to load.
|
||||
/// \param[in] row Audio & label read into this tensor row.
|
||||
/// \return Status The status code returned.
|
||||
Status LoadTensorRow(row_id_type row_id, TensorRow *row) override;
|
||||
|
||||
/// \brief Private function for computing the assignment of the column name map.
|
||||
/// \return Status The status code returned.
|
||||
Status ComputeColMap() override;
|
||||
|
||||
const std::string release_;
|
||||
const std::string dataset_dir_;
|
||||
const std::string usage_;
|
||||
const std::string extensions_;
|
||||
std::unique_ptr<DataSchema> data_schema_;
|
||||
|
||||
std::vector<std::vector<std::string> > audio_files_;
|
||||
std::vector<std::string> usage_list_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_TEDLIUM_OP_H_
|
|
@ -103,6 +103,7 @@ constexpr char kQMnistNode[] = "QMnistDataset";
|
|||
constexpr char kRandomNode[] = "RandomDataset";
|
||||
constexpr char kSBUNode[] = "SBUDataset";
|
||||
constexpr char kSpeechCommandsNode[] = "SpeechCommandsDataset";
|
||||
constexpr char kTedliumNode[] = "TedliumDataset";
|
||||
constexpr char kTextFileNode[] = "TextFileDataset";
|
||||
constexpr char kTFRecordNode[] = "TFRecordDataset";
|
||||
constexpr char kUSPSNode[] = "USPSDataset";
|
||||
|
|
|
@ -29,6 +29,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
|
|||
random_node.cc
|
||||
sbu_node.cc
|
||||
speech_commands_node.cc
|
||||
tedlium_node.cc
|
||||
text_file_node.cc
|
||||
tf_record_node.cc
|
||||
usps_node.cc
|
||||
|
|
|
@ -0,0 +1,152 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/tedlium_node.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <utility>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/tedlium_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// Constructor for TedliumNode.
|
||||
TedliumNode::TedliumNode(const std::string &dataset_dir, const std::string &release, const std::string &usage,
|
||||
const std::string &extensions, const std::shared_ptr<SamplerObj> &sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache)
|
||||
: MappableSourceNode(std::move(cache)),
|
||||
dataset_dir_(dataset_dir),
|
||||
release_(release),
|
||||
extensions_(extensions),
|
||||
usage_(usage),
|
||||
sampler_(sampler) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> TedliumNode::Copy() {
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
|
||||
auto node = std::make_shared<TedliumNode>(dataset_dir_, release_, usage_, extensions_, sampler, cache_);
|
||||
return node;
|
||||
}
|
||||
|
||||
// Writes "<name>(cache: true|false)" to the stream as a single string.
void TedliumNode::Print(std::ostream &out) const {
  const char *cache_state = (cache_ != nullptr) ? "true" : "false";
  std::string text = Name();
  text += "(cache: ";
  text += cache_state;
  text += ")";
  out << text;
}
|
||||
|
||||
// Checks the extensions argument: ".sph" is the only accepted value, anything else is logged and
// returned as a syntax error.
Status ValidateExtensionsParam(const std::string &dataset_name, const std::string &extensions) {
  const bool supported = (extensions == ".sph");
  if (!supported) {
    std::string err_msg = dataset_name;
    err_msg += ": extension ";
    err_msg += extensions;
    err_msg += " is not supported.";
    MS_LOG(ERROR) << err_msg;
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  return Status::OK();
}
|
||||
|
||||
// Validates the node arguments in this order: base class checks, dataset directory, release,
// extensions, sampler and finally the usage, whose accepted values depend on the release.
Status TedliumNode::ValidateParams() {
  RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
  RETURN_IF_NOT_OK(ValidateDatasetDirParam("TedliumNode", dataset_dir_));
  RETURN_IF_NOT_OK(ValidateStringValue("TedliumNode", release_, {"release1", "release2", "release3"}));
  RETURN_IF_NOT_OK(ValidateExtensionsParam("TedliumNode", extensions_));
  RETURN_IF_NOT_OK(ValidateDatasetSampler("TedliumNode", sampler_));

  if (release_ == "release3") {
    // release3 is not split by usage.
    RETURN_IF_NOT_OK(ValidateStringValue("TedliumNode", usage_, {"all"}));
  } else if (release_ == "release1" || release_ == "release2") {
    RETURN_IF_NOT_OK(ValidateStringValue("TedliumNode", usage_, {"dev", "train", "test", "all"}));
  }
  return Status::OK();
}
|
||||
|
||||
// Creates the runtime TedliumOp for this node and appends it to node_ops.
// The schema has a rank 1 float32 "waveform" column followed by the scalar columns sample_rate (int32),
// transcript, talk_id, speaker_id and identifier (strings).
Status TedliumNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
  auto op_schema = std::make_unique<DataSchema>();
  RETURN_IF_NOT_OK(
    op_schema->AddColumn(ColDescriptor("waveform", DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1)));
  TensorShape scalar_shape = TensorShape::CreateScalar();
  RETURN_IF_NOT_OK(op_schema->AddColumn(
    ColDescriptor("sample_rate", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar_shape)));
  for (const char *string_column : {"transcript", "talk_id", "speaker_id", "identifier"}) {
    RETURN_IF_NOT_OK(op_schema->AddColumn(
      ColDescriptor(string_column, DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar_shape)));
  }

  // Runtime sampler built from the sampler object of this node.
  std::shared_ptr<SamplerRT> runtime_sampler = nullptr;
  RETURN_IF_NOT_OK(sampler_->SamplerBuild(&runtime_sampler));

  std::shared_ptr<TedliumOp> op =
    std::make_shared<TedliumOp>(dataset_dir_, release_, usage_, extensions_, num_workers_, std::move(op_schema),
                                std::move(runtime_sampler), connector_que_size_);
  op->SetTotalRepeats(GetTotalRepeats());
  op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
  node_ops->push_back(op);
  return Status::OK();
}
|
||||
|
||||
// Writes the shard id reported by the sampler to *shard_id.
Status TedliumNode::GetShardId(int32_t *shard_id) {
  const auto sampler_shard_id = sampler_->ShardId();
  *shard_id = sampler_shard_id;
  return Status::OK();
}
|
||||
|
||||
// Returns the number of rows this node produces.
// A previously computed positive size is reused. Otherwise the rows are counted with
// TedliumOp::CountTotalRows and given to the sampler; if the sampler answers -1, a dry run through
// size_getter determines the size. The result is cached in dataset_size_.
Status TedliumNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                                   int64_t *dataset_size) {
  if (dataset_size_ > 0) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }

  int64_t total_rows = 0;
  RETURN_IF_NOT_OK(TedliumOp::CountTotalRows(dataset_dir_, release_, usage_, extensions_, &total_rows));

  // Let the sampler decide how many of the rows are actually sampled.
  std::shared_ptr<SamplerRT> runtime_sampler = nullptr;
  RETURN_IF_NOT_OK(sampler_->SamplerBuild(&runtime_sampler));
  int64_t sampled_rows = runtime_sampler->CalculateNumSamples(total_rows);
  if (sampled_rows == -1) {
    RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sampled_rows));
  }

  // Remember the size so that it is not computed again.
  *dataset_size = sampled_rows;
  dataset_size_ = sampled_rows;
  return Status::OK();
}
|
||||
|
||||
/// \brief Serialize this node's arguments into a JSON object.
/// \param[out] out_json Receives the assembled object; it is assigned only after every field was collected.
/// \return Status::OK() on success, otherwise the error reported by the sampler or cache serializer.
Status TedliumNode::to_json(nlohmann::json *out_json) {
  // Serialize the sampler first; a failure here returns before anything else is built.
  nlohmann::json sampler_json;
  RETURN_IF_NOT_OK(sampler_->to_json(&sampler_json));

  nlohmann::json node_json;
  node_json["sampler"] = sampler_json;
  node_json["num_parallel_workers"] = num_workers_;
  node_json["dataset_dir"] = dataset_dir_;
  node_json["release"] = release_;
  node_json["usage"] = usage_;
  node_json["extensions"] = extensions_;

  // The "cache" key is emitted only when a cache is attached to the node.
  if (cache_ != nullptr) {
    nlohmann::json cache_json;
    RETURN_IF_NOT_OK(cache_->to_json(&cache_json));
    node_json["cache"] = cache_json;
  }

  *out_json = node_json;
  return Status::OK();
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,110 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TEDLIUM_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TEDLIUM_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class TedliumNode : public MappableSourceNode {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
TedliumNode(const std::string &dataset_dir, const std::string &release, const std::string &usage,
|
||||
const std::string &extensions, const std::shared_ptr<SamplerObj> &sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Destructor.
|
||||
~TedliumNode() = default;
|
||||
|
||||
/// \brief Node name getter.
|
||||
/// \return Name of the current node.
|
||||
std::string Name() const override { return kTedliumNode; }
|
||||
|
||||
/// \brief Print the description.
|
||||
/// \param out - The output stream to write output to.
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object.
|
||||
/// \return A shared pointer to the new copy.
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class.
|
||||
/// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create.
|
||||
/// \return Status Status::OK() if build successfully.
|
||||
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
|
||||
|
||||
/// \brief Parameters validation.
|
||||
/// \return Status Status::OK() if all the parameters are valid.
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Get the shard id of node.
|
||||
/// \param[in] shard_id Shard id.
|
||||
/// \return Status Status::OK() if get shard id successfully.
|
||||
Status GetShardId(int32_t *shard_id) override;
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize.
|
||||
/// \param[in] size_getter Shared pointer to DatasetSizeGetter.
|
||||
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
|
||||
/// dataset size at the expense of accuracy.
|
||||
/// \param[out] dataset_size the size of the dataset.
|
||||
/// \return Status of the function.
|
||||
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||
int64_t *dataset_size) override;
|
||||
|
||||
/// \brief Get the arguments of node.
|
||||
/// \param[out] out_json JSON string of all attributes.
|
||||
/// \return Status of the function.
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// \brief Sampler getter.
|
||||
/// \return SamplerObj of the current node.
|
||||
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
|
||||
|
||||
/// \brief Sampler setter.
|
||||
/// \param[in] sampler Sampler object used to choose samples from the dataset.
|
||||
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
|
||||
|
||||
/// \brief Release getter.
|
||||
/// \return Release of the current node.
|
||||
const std::string &Release() const { return release_; }
|
||||
|
||||
/// \brief DatasetDir getter.
|
||||
/// \return DatasetDir of the current node.
|
||||
const std::string &DatasetDir() const { return dataset_dir_; }
|
||||
|
||||
/// \brief Usage getter.
|
||||
/// \return Usage of the current node.
|
||||
const std::string &Usage() const { return usage_; }
|
||||
|
||||
/// \brief Extensions getter.
|
||||
/// \return Extensions of the current node.
|
||||
const std::string &Extensions() const { return extensions_; }
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string release_;
|
||||
std::string usage_;
|
||||
std::string extensions_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
}; // class TedliumNode
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TEDLIUM_NODE_H_
|
|
@ -3614,6 +3614,109 @@ inline std::shared_ptr<SpeechCommandsDataset> SpeechCommands(const std::string &
|
|||
return std::make_shared<SpeechCommandsDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
|
||||
}
|
||||
|
||||
/// \class TedliumDataset
|
||||
/// \brief A source dataset for reading and parsing tedlium dataset.
|
||||
class MS_API TedliumDataset : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor of TedliumDataset.
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] release Release of the dataset, can be "release1", "release2", "release3".
|
||||
/// \param[in] usage Part of dataset of TEDLIUM, for release3, only can be "all", for release1 and release2,
|
||||
/// can be "train", "test" or "all".
|
||||
/// \param[in] extensions The extensions of audio file. Only support ".sph" now.
|
||||
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
|
||||
/// given, a `RandomSampler` will be used to randomly iterate the entire dataset.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
TedliumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &release, const std::vector<char> &usage,
|
||||
const std::vector<char> &extensions, const std::shared_ptr<Sampler> &sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Constructor of TedliumDataset.
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] release Release of the dataset, can be "release1", "release2", "release3".
|
||||
/// \param[in] usage Part of dataset of TEDLIUM, for release3, only can be "all", for release1 and release2,
|
||||
/// can be "train", "test" or "all".
|
||||
/// \param[in] extensions The extensions of audio file. Only support ".sph" now.
|
||||
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
TedliumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &release, const std::vector<char> &usage,
|
||||
const std::vector<char> &extensions, const Sampler *sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Constructor of TedliumDataset.
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] release Release of the dataset, can be "release1", "release2", "release3".
|
||||
/// \param[in] usage Part of dataset of TEDLIUM, for release3, only can be "all", for release1 and release2,
|
||||
/// can be "train", "test" or "all".
|
||||
/// \param[in] extensions The extensions of audio file. Only support ".sph" now.
|
||||
/// \param[in] sampler Sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
TedliumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &release, const std::vector<char> &usage,
|
||||
const std::vector<char> &extensions, const std::reference_wrapper<Sampler> &samlper,
|
||||
const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Destructor of TedliumDataset.
|
||||
~TedliumDataset() = default;
|
||||
};
|
||||
|
||||
/// \brief Function to create a TedliumDataset.
|
||||
/// \note The generated dataset has six columns ["waveform", "sample_rate", "transcript", "talk_id", "speaker_id",
|
||||
/// "identifier"].
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] release Release of the dataset, can be "release1", "release2", "release3".
|
||||
/// \param[in] usage Part of dataset of TEDLIUM, for release3, only can be "all", for release1 and release2,
|
||||
/// can be "train", "test" or "all" (default = "all").
|
||||
/// \param[in] extensions The extensions of audio file. Only support ".sph" now (default = ".sph").
|
||||
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
|
||||
/// given, a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()).
|
||||
/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
|
||||
/// \return Shared pointer to the TedliumDataset.
|
||||
inline std::shared_ptr<TedliumDataset> MS_API Tedlium(
|
||||
const std::string &dataset_dir, const std::string &release, const std::string &usage = "all",
|
||||
const std::string &extensions = ".sph", const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<TedliumDataset>(StringToChar(dataset_dir), StringToChar(release), StringToChar(usage),
|
||||
StringToChar(extensions), sampler, cache);
|
||||
}
|
||||
|
||||
/// \brief Function to create a TedliumDataset from a sampler passed as std::reference_wrapper.
|
||||
/// \note The generated dataset has six columns ["waveform", "sample_rate", "transcript", "talk_id", "speaker_id",
|
||||
/// "identifier"].
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] release Release of the dataset, can be "release1", "release2", "release3".
|
||||
/// \param[in] usage Part of dataset of TEDLIUM, for release3, can only be "all", for release1 and release2,
|
||||
/// can be "train", "test" or "all".
|
||||
/// \param[in] extensions The extensions of audio file. Only support ".sph" now.
|
||||
/// \param[in] sampler Sampler object (wrapped in std::reference_wrapper) used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
|
||||
/// \return Shared pointer to the TedliumDataset.
|
||||
inline std::shared_ptr<TedliumDataset> MS_API Tedlium(const std::string &dataset_dir, const std::string &release,
|
||||
const std::string &usage, const std::string &extensions,
|
||||
const std::reference_wrapper<Sampler> sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<TedliumDataset>(StringToChar(dataset_dir), StringToChar(release), StringToChar(usage),
|
||||
StringToChar(extensions), sampler, cache);
|
||||
}
|
||||
|
||||
/// \brief Function to create a TedliumDataset from a sampler passed as a raw pointer.
|
||||
/// \note The generated dataset has six columns ["waveform", "sample_rate", "transcript", "talk_id", "speaker_id",
|
||||
/// "identifier"].
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] release Release of the dataset, can be "release1", "release2", "release3".
|
||||
/// \param[in] usage Part of dataset of TEDLIUM, for release3, can only be "all", for release1 and release2,
|
||||
/// can be "train", "test" or "all".
|
||||
/// \param[in] extensions The extensions of audio file. Only support ".sph" now.
|
||||
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
|
||||
/// \return Shared pointer to the TedliumDataset.
|
||||
inline std::shared_ptr<TedliumDataset> MS_API Tedlium(const std::string &dataset_dir, const std::string &release,
|
||||
const std::string &usage, const std::string &extensions,
|
||||
Sampler *sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<TedliumDataset>(StringToChar(dataset_dir), StringToChar(release), StringToChar(usage),
|
||||
StringToChar(extensions), sampler, cache);
|
||||
}
|
||||
|
||||
/// \class TextFileDataset
|
||||
/// \brief A source dataset that reads and parses datasets stored on disk in text format.
|
||||
class MS_API TextFileDataset : public Dataset {
|
||||
|
|
|
@ -56,6 +56,7 @@ class MS_API Sampler : std::enable_shared_from_this<Sampler> {
|
|||
friend class RandomDataDataset;
|
||||
friend class SBUDataset;
|
||||
friend class SpeechCommandsDataset;
|
||||
friend class TedliumDataset;
|
||||
friend class TextFileDataset;
|
||||
friend class TFRecordDataset;
|
||||
friend class USPSDataset;
|
||||
|
|
|
@ -70,7 +70,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
|
|||
check_sb_dataset, check_flowers102dataset, check_cityscapes_dataset, check_usps_dataset, check_div2k_dataset, \
|
||||
check_sbu_dataset, check_qmnist_dataset, check_emnist_dataset, check_fake_image_dataset, check_places365_dataset, \
|
||||
check_photo_tour_dataset, check_ag_news_dataset, check_dbpedia_dataset, check_lj_speech_dataset, \
|
||||
check_yes_no_dataset, check_speech_commands_dataset
|
||||
check_yes_no_dataset, check_speech_commands_dataset, check_tedlium_dataset
|
||||
from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
|
||||
get_prefetch_size, get_auto_offload
|
||||
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
|
||||
|
@ -8625,3 +8625,218 @@ class YesNoDataset(MappableDataset):
|
|||
|
||||
def parse(self, children=None):
    """Create the YesNo IR node from this dataset's directory and sampler; `children` is not used."""
    node_args = (self.dataset_dir, self.sampler)
    return cde.YesNoNode(*node_args)
|
||||
|
||||
|
||||
class TedliumDataset(MappableDataset):
|
||||
"""
|
||||
A source dataset for reading and parsing Tedlium dataset.
|
||||
The columns of generated dataset depend on the source SPH files and the corresponding STM files.
|
||||
|
||||
The generated dataset has six columns :py:obj:`[waveform, sample_rate, transcript, talk_id, speaker_id,
|
||||
identifier]`.
|
||||
|
||||
The tensor of column :py:obj:`waveform` is of the float32 type.
|
||||
The tensor of column :py:obj:`sample_rate` is a scalar of the int32 type.
|
||||
The tensor of column :py:obj:`transcript` is a scalar of the string type.
|
||||
The tensor of column :py:obj:`talk_id` is a scalar of the string type.
|
||||
The tensor of column :py:obj:`speaker_id` is a scalar of the string type.
|
||||
The tensor of column :py:obj:`identifier` is a scalar of the string type.
|
||||
|
||||
Args:
|
||||
dataset_dir (str): Path to the root directory that contains the dataset.
|
||||
release (str): Release of the dataset, can be "release1", "release2", "release3".
|
||||
usage (str, optional): Usage of this dataset.
|
||||
For release1 or release2, can be `train`, `test`, `dev` or `all`.
|
||||
`train` will read from train samples,
|
||||
`test` will read from test samples,
|
||||
`dev` will read from dev samples,
|
||||
`all` will read from all samples.
|
||||
For release3, can only be "all", which reads all data samples (default=None, all samples).
|
||||
extensions (str, optional): Extensions of the SPH files, only '.sph' is valid
|
||||
(default=None, ".sph").
|
||||
num_samples (int, optional): The number of audio samples to be included in the dataset
|
||||
(default=None, all samples).
|
||||
num_parallel_workers (int, optional): Number of workers to read the data
|
||||
(default=None, number set in the config).
|
||||
shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected
|
||||
order behavior shown in the table).
|
||||
sampler (Sampler, optional): Object used to choose samples from the
|
||||
dataset (default=None, expected order behavior shown in the table).
|
||||
num_shards (int, optional): Number of shards that the dataset will be divided
|
||||
into (default=None). When this argument is specified, `num_samples` reflects
|
||||
the maximum sample number of per shard.
|
||||
shard_id (int, optional): The shard ID within num_shards (default=None). This
|
||||
argument can only be specified when num_shards is also specified.
|
||||
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
|
||||
(default=None, which means no cache is used).
|
||||
|
||||
Raises:
|
||||
RuntimeError: If dataset_dir does not contain stm files.
|
||||
RuntimeError: If num_parallel_workers exceeds the max thread numbers.
|
||||
RuntimeError: If sampler and shuffle are specified at the same time.
|
||||
RuntimeError: If sampler and sharding are specified at the same time.
|
||||
RuntimeError: If num_shards is specified but shard_id is None.
|
||||
RuntimeError: If shard_id is specified but num_shards is None.
|
||||
ValueError: If shard_id is invalid (< 0 or >= num_shards).
|
||||
|
||||
Note:
|
||||
- This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive.
|
||||
The table below shows what input arguments are allowed and their expected behavior.
|
||||
|
||||
.. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
|
||||
:widths: 25 25 50
|
||||
:header-rows: 1
|
||||
|
||||
* - Parameter `sampler`
|
||||
- Parameter `shuffle`
|
||||
- Expected Order Behavior
|
||||
* - None
|
||||
- None
|
||||
- random order
|
||||
* - None
|
||||
- True
|
||||
- random order
|
||||
* - None
|
||||
- False
|
||||
- sequential order
|
||||
* - Sampler object
|
||||
- None
|
||||
- order defined by sampler
|
||||
* - Sampler object
|
||||
- True
|
||||
- not allowed
|
||||
* - Sampler object
|
||||
- False
|
||||
- not allowed
|
||||
|
||||
Examples:
|
||||
>>> tedlium_dataset_dir = "/path/to/tedlium_dataset_directory"
|
||||
>>> tedlium_dataset_release = ["release1", "release2", "release3"]
|
||||
>>>
|
||||
>>> # 1) Get all samples from TEDLIUM_release1 dataset in sequence.
|
||||
>>> dataset = ds.TedliumDataset(dataset_dir=tedlium_dataset_dir, release=tedlium_dataset_release[0],
|
||||
... shuffle=False)
|
||||
>>>
|
||||
>>> # 2) Randomly select 10 samples from TEDLIUM_release2 dataset.
|
||||
>>> dataset = ds.TedliumDataset(dataset_dir=tedlium_dataset_dir, release=tedlium_dataset_release[1],
|
||||
... num_samples=10, shuffle=True)
|
||||
>>>
|
||||
>>> # 3) Get samples from TEDLIUM_release-3 dataset for shard 0 in a 2-way distributed training.
|
||||
>>> dataset = ds.TedliumDataset(dataset_dir=tedlium_dataset_dir, release=tedlium_dataset_release[2],
|
||||
... num_shards=2, shard_id=0)
|
||||
>>>
|
||||
>>> # In TEDLIUM dataset, each dictionary has keys : waveform, sample_rate, transcript, talk_id,
|
||||
>>> # speaker_id and identifier.
|
||||
|
||||
About TEDLIUM_release1 dataset:
|
||||
|
||||
The TED-LIUM corpus is English-language TED talks, with transcriptions, sampled at 16kHz.
|
||||
It contains about 118 hours of speech.
|
||||
|
||||
About TEDLIUM_release2 dataset:
|
||||
|
||||
This is the TED-LIUM corpus release 2, licensed under Creative Commons BY-NC-ND 3.0. All talks and text are
|
||||
property of TED Conferences LLC. The TED-LIUM corpus was made from audio talks and their transcriptions available
|
||||
on the TED website. We have prepared and filtered these data in order to train acoustic models to participate in
|
||||
the International Workshop on Spoken Language Translation 2011 (the LIUM English/French SLT system reached the
|
||||
first rank in the SLT task).
|
||||
|
||||
About TEDLIUM_release-3 dataset:
|
||||
|
||||
This is the TED-LIUM corpus release 3, licensed under Creative Commons BY-NC-ND 3.0. All talks and text are
|
||||
property of TED Conferences LLC. This new TED-LIUM release was made through a collaboration between the Ubiqus
|
||||
company and the LIUM (University of Le Mans, France).
|
||||
|
||||
You can unzip the dataset files into the following directory structure and read by MindSpore's API.
|
||||
|
||||
The structure of TEDLIUM release2 is the same as TEDLIUM release1, only the data is different.
|
||||
|
||||
.. code-block::
|
||||
|
||||
.
|
||||
└──TEDLIUM_release1
|
||||
└── dev
|
||||
├── sph
|
||||
├── AlGore_2009.sph
|
||||
├── BarrySchwartz_2005G.sph
|
||||
├── stm
|
||||
├── AlGore_2009.stm
|
||||
├── BarrySchwartz_2005G.stm
|
||||
└── test
|
||||
├── sph
|
||||
├── AimeeMullins_2009P.sph
|
||||
├── BillGates_2010.sph
|
||||
├── stm
|
||||
├── AimeeMullins_2009P.stm
|
||||
├── BillGates_2010.stm
|
||||
└── train
|
||||
├── sph
|
||||
├── AaronHuey_2010X.sph
|
||||
├── AdamGrosser_2007.sph
|
||||
├── stm
|
||||
├── AaronHuey_2010X.stm
|
||||
├── AdamGrosser_2007.stm
|
||||
└── readme
|
||||
└── TEDLIUM.150k.dic
|
||||
|
||||
.. code-block::
|
||||
|
||||
.
|
||||
└──TEDLIUM_release-3
|
||||
└── data
|
||||
├── ctl
|
||||
├── sph
|
||||
├── 911Mothers_2010W.sph
|
||||
├── AalaElKhani.sph
|
||||
├── stm
|
||||
├── 911Mothers_2010W.stm
|
||||
├── AalaElKhani.stm
|
||||
└── doc
|
||||
└── legacy
|
||||
└── LM
|
||||
└── speaker-adaptation
|
||||
└── readme
|
||||
└── TEDLIUM.150k.dic
|
||||
|
||||
Citation:
|
||||
|
||||
.. code-block::
|
||||
|
||||
@article{
|
||||
title={TED-LIUM: an automatic speech recognition dedicated corpus},
|
||||
author={A. Rousseau, P. Deléglise, Y. Estève},
|
||||
journal={Proceedings of the Eighth International Conference on Language Resources and Evaluation (LREC'12)},
|
||||
year={May 2012},
|
||||
biburl={https://www.openslr.org/7/}
|
||||
}
|
||||
|
||||
@article{
|
||||
title={Enhancing the TED-LIUM Corpus with Selected Data for Language Modeling and More TED Talks},
|
||||
author={A. Rousseau, P. Deléglise, and Y. Estève},
|
||||
journal={Proceedings of the Eighth International Conference on Language Resources and Evaluation (LREC'12)},
|
||||
year={May 2014},
|
||||
biburl={https://www.openslr.org/19/}
|
||||
}
|
||||
|
||||
@article{
|
||||
title={TED-LIUM 3: twice as much data and corpus repartition for experiments on speaker adaptation},
|
||||
author={François Hernandez, Vincent Nguyen, Sahar Ghannay, Natalia Tomashenko, and Yannick Estève},
|
||||
journal={the 20th International Conference on Speech and Computer (SPECOM 2018)},
|
||||
year={September 2018},
|
||||
biburl={https://www.openslr.org/51/}
|
||||
}
|
||||
"""
|
||||
|
||||
@check_tedlium_dataset
def __init__(self, dataset_dir, release, usage=None, extensions=None, num_samples=None,
             num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None,
             shard_id=None, cache=None):
    """Store the TEDLIUM-specific arguments after forwarding the common ones to the base class.

    `usage` falls back to "all" and `extensions` to ".sph" when they are None.
    """
    base_options = {
        "num_parallel_workers": num_parallel_workers,
        "sampler": sampler,
        "num_samples": num_samples,
        "shuffle": shuffle,
        "num_shards": num_shards,
        "shard_id": shard_id,
        "cache": cache,
    }
    super().__init__(**base_options)

    self.dataset_dir = dataset_dir
    self.release = release
    self.usage = replace_none(usage, "all")
    self.extensions = replace_none(extensions, ".sph")
|
||||
|
||||
def parse(self, children=None):
    """Create the Tedlium IR node from the stored directory, release, usage, extensions and sampler.

    `children` is accepted but not used here.
    """
    node_args = (self.dataset_dir, self.release, self.usage, self.extensions, self.sampler)
    return cde.TedliumNode(*node_args)
|
||||
|
|
|
@ -1863,3 +1863,39 @@ def check_yes_no_dataset(method):
|
|||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_tedlium_dataset(method):
    """Wrapper method to check the parameters of TedliumDataset.

    Checks, in order: `release` is one of the three known releases, `dataset_dir` passes
    `check_dir`, `usage` (when given) is allowed for the chosen release, `extensions`
    (when given) is ".sph", the integer and boolean options have the right type, the
    sampler/shuffle/shard combination is accepted, and the cache option is accepted.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle']

        release = param_dict.get('release')
        check_valid_str(release, ["release1", "release2", "release3"], "release")

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        # The set of accepted usages depends on the release: release3 only offers "all".
        usage = param_dict.get('usage')
        if usage is not None:
            if release in ["release1", "release2"]:
                check_valid_str(usage, ["train", "test", "dev", "all"], "usage")
            else:
                check_valid_str(usage, ["all"], "usage")

        # Only ".sph" is documented as a supported extension, so reject anything else up front
        # instead of letting an unsupported value reach the dataset node.
        extensions = param_dict.get('extensions')
        if extensions is not None:
            check_valid_str(extensions, [".sph"], "extensions")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)

        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method
|
||||
|
|
|
@ -40,6 +40,7 @@ SET(DE_UT_SRCS
|
|||
c_api_dataset_save.cc
|
||||
c_api_dataset_sbu_test.cc
|
||||
c_api_dataset_speech_commands_test.cc
|
||||
c_api_dataset_tedlium_test.cc
|
||||
c_api_dataset_textfile_test.cc
|
||||
c_api_dataset_tfrecord_test.cc
|
||||
c_api_dataset_usps_test.cc
|
||||
|
|
|
@ -0,0 +1,383 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/common.h"
|
||||
|
||||
#include "minddata/dataset/include/dataset/datasets.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::dataset::DataType;
|
||||
using mindspore::dataset::Tensor;
|
||||
using mindspore::dataset::TensorShape;
|
||||
|
||||
/// \brief Fixture for the pipeline tests in this file; it adds no members to UT::DatasetOpTesting.
class MindDataTestPipeline : public UT::DatasetOpTesting {};
|
||||
|
||||
/// Feature: TedliumDataset.
|
||||
/// Description: read some samples from all files according to different versions.
|
||||
/// Expectation: 4 * 2 samples.
|
||||
TEST_F(MindDataTestPipeline, TestTedliumDataset) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTedliumDataset.";
|
||||
|
||||
// Create a Tedlium Dataset.
|
||||
std::string folder_path12 = datasets_root_path_ + "/testTedliumData/TEDLIUM_release1";
|
||||
std::string folder_path3 = datasets_root_path_ + "/testTedliumData/TEDLIUM_release3";
|
||||
std::shared_ptr<Dataset> ds1 =
|
||||
Tedlium(folder_path12, "release1", "all", ".sph", std::make_shared<RandomSampler>(false, 4), nullptr);
|
||||
std::shared_ptr<Dataset> ds3 =
|
||||
Tedlium(folder_path3, "release3", "all", ".sph", std::make_shared<RandomSampler>(false, 4), nullptr);
|
||||
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
EXPECT_NE(ds3, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset.
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
|
||||
std::shared_ptr<Iterator> iter3 = ds3->CreateIterator();
|
||||
EXPECT_NE(iter1, nullptr);
|
||||
EXPECT_NE(iter3, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row.
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row1;
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row3;
|
||||
ASSERT_OK(iter1->GetNextRow(&row1));
|
||||
|
||||
EXPECT_NE(row1.find("waveform"), row1.end());
|
||||
EXPECT_NE(row1.find("sample_rate"), row1.end());
|
||||
EXPECT_NE(row1.find("transcript"), row1.end());
|
||||
EXPECT_NE(row1.find("talk_id"), row1.end());
|
||||
EXPECT_NE(row1.find("speaker_id"), row1.end());
|
||||
EXPECT_NE(row1.find("identifier"), row1.end());
|
||||
|
||||
ASSERT_OK(iter3->GetNextRow(&row3));
|
||||
|
||||
EXPECT_NE(row3.find("waveform"), row3.end());
|
||||
EXPECT_NE(row3.find("sample_rate"), row3.end());
|
||||
EXPECT_NE(row3.find("transcript"), row3.end());
|
||||
EXPECT_NE(row3.find("talk_id"), row3.end());
|
||||
EXPECT_NE(row3.find("speaker_id"), row3.end());
|
||||
EXPECT_NE(row3.find("identifier"), row3.end());
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row1.size() != 0) {
|
||||
i++;
|
||||
auto audio = row1["waveform"];
|
||||
MS_LOG(INFO) << "Tensor audio shape: " << audio.Shape();
|
||||
ASSERT_OK(iter1->GetNextRow(&row1));
|
||||
}
|
||||
while (row3.size() != 0) {
|
||||
i++;
|
||||
auto audio = row3["waveform"];
|
||||
MS_LOG(INFO) << "Tensor audio shape: " << audio.Shape();
|
||||
ASSERT_OK(iter3->GetNextRow(&row3));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 4 * 2);
|
||||
|
||||
// Manually terminate the pipeline.
|
||||
iter1->Stop();
|
||||
iter3->Stop();
|
||||
}
|
||||
|
||||
/// Feature: TedliumDataset.
|
||||
/// Description: read some samples with pipeline from all files according to different versions.
|
||||
/// Expectation: 8 * 2 samples.
|
||||
TEST_F(MindDataTestPipeline, TestTedliumDatasetWithPipeline) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTedliumDatasetWithPipeline.";

  // Create four Tedlium datasets: two over the release1 fixture and two over the release3 fixture,
  // each configured with RandomSampler(false, 4).
  std::string folder_path12 = datasets_root_path_ + "/testTedliumData/TEDLIUM_release1";
  std::string folder_path3 = datasets_root_path_ + "/testTedliumData/TEDLIUM_release3";
  std::shared_ptr<Dataset> ds11 =
    Tedlium(folder_path12, "release1", "all", ".sph", std::make_shared<RandomSampler>(false, 4), nullptr);
  std::shared_ptr<Dataset> ds31 =
    Tedlium(folder_path3, "release3", "all", ".sph", std::make_shared<RandomSampler>(false, 4), nullptr);
  std::shared_ptr<Dataset> ds12 =
    Tedlium(folder_path12, "release1", "all", ".sph", std::make_shared<RandomSampler>(false, 4), nullptr);
  std::shared_ptr<Dataset> ds32 =
    Tedlium(folder_path3, "release3", "all", ".sph", std::make_shared<RandomSampler>(false, 4), nullptr);
  // Fatal checks: every pointer below is dereferenced, so continuing after a null would crash the test binary.
  ASSERT_NE(ds11, nullptr);
  ASSERT_NE(ds12, nullptr);
  ASSERT_NE(ds31, nullptr);
  ASSERT_NE(ds32, nullptr);

  // Create a Repeat operation on each of the four datasets.
  const int32_t repeat_num = 1;
  ds11 = ds11->Repeat(repeat_num);
  ds31 = ds31->Repeat(repeat_num);
  ASSERT_NE(ds11, nullptr);
  ASSERT_NE(ds31, nullptr);
  ds12 = ds12->Repeat(repeat_num);
  ds32 = ds32->Repeat(repeat_num);
  ASSERT_NE(ds12, nullptr);
  ASSERT_NE(ds32, nullptr);

  // Create a Project operation on each of the four datasets, keeping all six Tedlium columns.
  std::vector<std::string> column_project = {"waveform", "sample_rate", "transcript",
                                             "talk_id",  "speaker_id",  "identifier"};
  ds11 = ds11->Project(column_project);
  ASSERT_NE(ds11, nullptr);
  ds12 = ds12->Project(column_project);
  ASSERT_NE(ds12, nullptr);
  ds31 = ds31->Project(column_project);
  ASSERT_NE(ds31, nullptr);
  ds32 = ds32->Project(column_project);
  ASSERT_NE(ds32, nullptr);

  // Create a Concat operation per release: ds11 + ds12 and ds31 + ds32.
  ds11 = ds11->Concat({ds12});
  ds31 = ds31->Concat({ds32});
  ASSERT_NE(ds11, nullptr);
  ASSERT_NE(ds31, nullptr);

  // Create an iterator over the result of the above datasets.
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter1 = ds11->CreateIterator();
  std::shared_ptr<Iterator> iter3 = ds31->CreateIterator();
  // Must be fatal: GetNextRow() is called through these pointers right below.
  ASSERT_NE(iter1, nullptr);
  ASSERT_NE(iter3, nullptr);

  // Iterate the datasets and get each row.
  std::unordered_map<std::string, mindspore::MSTensor> row1;
  std::unordered_map<std::string, mindspore::MSTensor> row3;
  ASSERT_OK(iter1->GetNextRow(&row1));
  ASSERT_OK(iter3->GetNextRow(&row3));

  // The first row of each pipeline must expose every projected column.
  EXPECT_NE(row1.find("waveform"), row1.end());
  EXPECT_NE(row1.find("sample_rate"), row1.end());
  EXPECT_NE(row1.find("transcript"), row1.end());
  EXPECT_NE(row1.find("talk_id"), row1.end());
  EXPECT_NE(row1.find("speaker_id"), row1.end());
  EXPECT_NE(row1.find("identifier"), row1.end());

  EXPECT_NE(row3.find("waveform"), row3.end());
  EXPECT_NE(row3.find("sample_rate"), row3.end());
  EXPECT_NE(row3.find("transcript"), row3.end());
  EXPECT_NE(row3.find("talk_id"), row3.end());
  EXPECT_NE(row3.find("speaker_id"), row3.end());
  EXPECT_NE(row3.find("identifier"), row3.end());

  // A single counter accumulates the rows drained from both pipelines.
  uint64_t i = 0;
  while (row1.size() != 0) {
    i++;
    auto audio = row1["waveform"];
    MS_LOG(INFO) << "Tensor audio shape: " << audio.Shape();
    ASSERT_OK(iter1->GetNextRow(&row1));
  }
  while (row3.size() != 0) {
    i++;
    auto audio = row3["waveform"];
    MS_LOG(INFO) << "Tensor audio shape: " << audio.Shape();
    ASSERT_OK(iter3->GetNextRow(&row3));
  }

  // Expected total over both concatenated pipelines.
  EXPECT_EQ(i, 8 * 2);

  // Manually terminate the pipelines.
  iter1->Stop();
  iter3->Stop();
}
|
||||
|
||||
/// Feature: TedliumDataset.
|
||||
/// Description: read number of all samples from all files according to different versions.
|
||||
/// Expectation: TEDLIUM_release12 : 1 + 2 + 3
|
||||
/// TEDLIUM_release3 : 3 + 4
|
||||
TEST_F(MindDataTestPipeline, TestTedliumGetDatasetSize) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTedliumGetDatasetSize.";
|
||||
|
||||
// Create a Tedlium Dataset.
|
||||
std::string folder_path12 = datasets_root_path_ + "/testTedliumData/TEDLIUM_release1";
|
||||
std::string folder_path3 = datasets_root_path_ + "/testTedliumData/TEDLIUM_release3";
|
||||
std::shared_ptr<Dataset> ds1 = Tedlium(folder_path12, "release1", "all", ".sph");
|
||||
std::shared_ptr<Dataset> ds3 = Tedlium(folder_path3, "release3", "all", ".sph");
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
EXPECT_NE(ds3, nullptr);
|
||||
|
||||
EXPECT_EQ(ds1->GetDatasetSize(), 1 + 2 + 3);
|
||||
EXPECT_EQ(ds3->GetDatasetSize(), 3 + 4);
|
||||
}
|
||||
|
||||
/// Feature: TedliumDataset.
|
||||
/// Description: Includes tests for shape, type, size.
|
||||
/// Expectation: correct shape, type, size.
|
||||
TEST_F(MindDataTestPipeline, TestTedliumGetters) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTedliumGetters.";
|
||||
|
||||
// Create a Tedlium Dataset.
|
||||
std::string folder_path = datasets_root_path_ + "/testTedliumData/TEDLIUM_release1";
|
||||
std::shared_ptr<Dataset> ds = Tedlium(folder_path, "release1", "all", ".sph");
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 1 + 2 + 3);
|
||||
std::vector<DataType> types = ToDETypes(ds->GetOutputTypes());
|
||||
std::vector<TensorShape> shapes = ToTensorShapeVec(ds->GetOutputShapes());
|
||||
std::vector<std::string> column_names = {"waveform", "sample_rate", "transcript",
|
||||
"talk_id", "speaker_id", "identifier"};
|
||||
int64_t num_classes = ds->GetNumClasses();
|
||||
EXPECT_EQ(types.size(), 6);
|
||||
EXPECT_EQ(types[0].ToString(), "float32");
|
||||
EXPECT_EQ(types[1].ToString(), "int32");
|
||||
EXPECT_EQ(types[2].ToString(), "string");
|
||||
EXPECT_EQ(types[3].ToString(), "string");
|
||||
EXPECT_EQ(types[4].ToString(), "string");
|
||||
EXPECT_EQ(types[5].ToString(), "string");
|
||||
|
||||
EXPECT_EQ(shapes.size(), 6);
|
||||
EXPECT_EQ(shapes[1].ToString(), "<>");
|
||||
EXPECT_EQ(shapes[2].ToString(), "<>");
|
||||
EXPECT_EQ(shapes[3].ToString(), "<>");
|
||||
EXPECT_EQ(shapes[4].ToString(), "<>");
|
||||
EXPECT_EQ(shapes[5].ToString(), "<>");
|
||||
EXPECT_EQ(num_classes, -1);
|
||||
EXPECT_EQ(ds->GetBatchSize(), 1);
|
||||
EXPECT_EQ(ds->GetRepeatCount(), 1);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 1 + 2 + 3);
|
||||
EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
|
||||
EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
|
||||
EXPECT_EQ(ds->GetNumClasses(), -1);
|
||||
|
||||
EXPECT_EQ(ds->GetColumnNames(), column_names);
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 1 + 2 + 3);
|
||||
EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
|
||||
EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
|
||||
EXPECT_EQ(ds->GetBatchSize(), 1);
|
||||
EXPECT_EQ(ds->GetRepeatCount(), 1);
|
||||
EXPECT_EQ(ds->GetNumClasses(), -1);
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 1 + 2 + 3);
|
||||
}
|
||||
|
||||
/// Feature: TedliumDataset.
|
||||
/// Description: test with invalid release.
|
||||
/// Expectation: unable to read in data.
|
||||
TEST_F(MindDataTestPipeline, TestTedliumDatasetWithInvalidReleaseFail) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTedliumDatasetWithInvalidReleaseFail.";
|
||||
|
||||
// Create a Tedlium Dataset.
|
||||
std::string folder_path12 = datasets_root_path_ + "/testTedliumData/TEDLIUM_release1";
|
||||
std::string folder_path3 = datasets_root_path_ + "/testTedliumData/TEDLIUM_release3";
|
||||
std::shared_ptr<Dataset> ds1 = Tedlium(folder_path12, "", "all", ".sph");
|
||||
std::shared_ptr<Dataset> ds2 = Tedlium(folder_path12, "RELEASE2", "all", ".sph");
|
||||
std::shared_ptr<Dataset> ds3 = Tedlium(folder_path3, "2", "all", ".sph");
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
EXPECT_NE(ds3, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset.
|
||||
std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
|
||||
std::shared_ptr<Iterator> iter2 = ds2->CreateIterator();
|
||||
std::shared_ptr<Iterator> iter3 = ds3->CreateIterator();
|
||||
// Expect failure: invalid Tedlium input, "", "RELEASE2" and "2" are not a valid release.
|
||||
EXPECT_EQ(iter1, nullptr);
|
||||
EXPECT_EQ(iter2, nullptr);
|
||||
EXPECT_EQ(iter3, nullptr);
|
||||
}
|
||||
|
||||
/// Feature: TedliumDataset.
|
||||
/// Description: test with invalid path.
|
||||
/// Expectation: unable to read in data.
|
||||
TEST_F(MindDataTestPipeline, TestTedliumDatasetFail) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTedliumDatasetFail.";
|
||||
|
||||
// Create a Tedlium Dataset.
|
||||
std::shared_ptr<Dataset> ds1 = Tedlium("", "release1", "all", ".sph", std::make_shared<RandomSampler>(false, 4));
|
||||
std::shared_ptr<Dataset> ds2 =
|
||||
Tedlium("validation", "release2", "all", ".sph", std::make_shared<RandomSampler>(false, 4));
|
||||
std::shared_ptr<Dataset> ds3 = Tedlium("2", "release3", "all", ".sph", std::make_shared<RandomSampler>(false, 4));
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
EXPECT_NE(ds3, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset.
|
||||
std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
|
||||
std::shared_ptr<Iterator> iter2 = ds2->CreateIterator();
|
||||
std::shared_ptr<Iterator> iter3 = ds3->CreateIterator();
|
||||
// Expect failure: invalid Tedlium input, "", "validation" and "2" are not a valid path.
|
||||
EXPECT_EQ(iter1, nullptr);
|
||||
EXPECT_EQ(iter2, nullptr);
|
||||
EXPECT_EQ(iter3, nullptr);
|
||||
}
|
||||
|
||||
/// Feature: TedliumDataset.
|
||||
/// Description: test with invalid usage.
|
||||
/// Expectation: unable to read in data.
|
||||
TEST_F(MindDataTestPipeline, TestTedliumDatasetWithInvalidUsageFail) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTedliumDatasetWithInvalidUsageFail.";
|
||||
|
||||
// Create a Tedlium Dataset.
|
||||
std::string folder_path12 = datasets_root_path_ + "/testTedliumData/TEDLIUM_release1";
|
||||
std::string folder_path3 = datasets_root_path_ + "/testTedliumData/TEDLIUM_release3";
|
||||
std::shared_ptr<Dataset> ds1 = Tedlium(folder_path12, "release1", "", ".sph");
|
||||
std::shared_ptr<Dataset> ds2 = Tedlium(folder_path12, "release2", "DEV", ".sph");
|
||||
std::shared_ptr<Dataset> ds3 = Tedlium(folder_path3, "release3", "2", ".sph");
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
EXPECT_NE(ds3, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset.
|
||||
std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
|
||||
std::shared_ptr<Iterator> iter2 = ds2->CreateIterator();
|
||||
std::shared_ptr<Iterator> iter3 = ds3->CreateIterator();
|
||||
// Expect failure: invalid Tedlium input, "", "DEV" and "2" are not a valid usage.
|
||||
EXPECT_EQ(iter1, nullptr);
|
||||
EXPECT_EQ(iter2, nullptr);
|
||||
EXPECT_EQ(iter3, nullptr);
|
||||
}
|
||||
|
||||
/// Feature: TedliumDataset.
|
||||
/// Description: test with invalid extensions.
|
||||
/// Expectation: unable to read in data.
|
||||
TEST_F(MindDataTestPipeline, TestTedliumDatasetWithInvalidExtensionsFail) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTedliumDatasetWithInvalidExtensionsFail.";
|
||||
|
||||
// Create a Tedlium Dataset.
|
||||
std::string folder_path12 = datasets_root_path_ + "/testTedliumData/TEDLIUM_release1";
|
||||
std::string folder_path3 = datasets_root_path_ + "/testTedliumData/TEDLIUM_release3";
|
||||
std::shared_ptr<Dataset> ds1 = Tedlium(folder_path12, "release1", "all", "sph");
|
||||
std::shared_ptr<Dataset> ds2 = Tedlium(folder_path12, "release2", "all", ".SPH");
|
||||
std::shared_ptr<Dataset> ds3 = Tedlium(folder_path3, "release3", "all", ".stm");
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
EXPECT_NE(ds3, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset.
|
||||
std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
|
||||
std::shared_ptr<Iterator> iter2 = ds2->CreateIterator();
|
||||
std::shared_ptr<Iterator> iter3 = ds3->CreateIterator();
|
||||
// Expect failure: invalid Tedlium input, "sph", ".SPH", ".stm" are not a valid extensions.
|
||||
EXPECT_EQ(iter1, nullptr);
|
||||
EXPECT_EQ(iter2, nullptr);
|
||||
EXPECT_EQ(iter3, nullptr);
|
||||
}
|
||||
|
||||
/// Feature: TedliumDataset.
|
||||
/// Description: test with null sampler.
|
||||
/// Expectation: unable to read in data.
|
||||
TEST_F(MindDataTestPipeline, TestTedliumDatasetWithNullSamplerFail) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTedliumDatasetWithNullSamplerFail.";
|
||||
|
||||
// Create a Tedlium Dataset.
|
||||
std::string folder_path = datasets_root_path_ + "/testTedliumData/TEDLIUM_release1";
|
||||
std::shared_ptr<Dataset> ds = Tedlium(folder_path, "release1", "all", ".sph", nullptr);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
// Expect failure: invalid Tedlium input, sampler cannot be nullptr.
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
Binary file not shown.
|
@ -0,0 +1 @@
|
|||
test1 1 test1 0.00 0.03 <o,f0,female> this is record 1 of test1.
|
Binary file not shown.
|
@ -0,0 +1,2 @@
|
|||
test2 1 test2 0.00 0.02 <o,f0,female> this is record 1 of test2.
|
||||
test2 1 test2 0.02 0.09 <o,f0,female> this is record 2 of test2.
|
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1 @@
|
|||
test1 1 test1 0.00 0.03 <o,f0,female> this is record 1 of test1.
|
|
@ -0,0 +1,2 @@
|
|||
test2 1 test2 0.00 0.02 <o,f0,female> this is record 1 of test2.
|
||||
test2 1 test2 0.02 0.09 <o,f0,female> this is record 2 of test2.
|
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,3 @@
|
|||
test3 1 test3 0.00 0.01 <o,f0,female> this is record 1 of test3.
|
||||
test3 1 test3 0.02 0.07 <o,f0,female> this is record 1 of test3.
|
||||
test3 1 test3 0.07 0.09 <o,f0,female> this is record 1 of test3.
|
|
@ -0,0 +1,4 @@
|
|||
test4 1 test4 0.00 0.01 <o,f0,female> this is record 1 of test4.
|
||||
test4 1 test4 0.02 0.03 <o,f0,female> this is record 2 of test4.
|
||||
test4 1 test4 0.05 0.07 <o,f0,female> this is record 3 of test4.
|
||||
test4 1 test4 0.07 0.09 <o,f0,female> this is record 4 of test4.
|
|
@ -0,0 +1 @@
|
|||
test1 1 test1 0.00 0.03 <o,f0,female> this is record 1 of test1.
|
Binary file not shown.
|
@ -0,0 +1,405 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.audio.transforms as audio
|
||||
|
||||
# Fixture directories, relative to the directory the tests are run from.
# The "RELEASE12" directory is used by the tests for both release1 and release2.
DATA_DIR_TEDLIUM_RELEASE12 = "../data/dataset/testTedliumData/TEDLIUM_release1"
DATA_DIR_TEDLIUM_RELEASE3 = "../data/dataset/testTedliumData/TEDLIUM_release3"

# Values accepted for the `release` argument of TedliumDataset.
RELEASE1 = "release1"
RELEASE2 = "release2"
RELEASE3 = "release3"

# Directory used by the "no valid data" negative test.
NO_SPH_DIR_TEDLIUM12 = "../data/dataset/testTedliumData/else"
|
||||
|
||||
|
||||
def test_tedlium_basic():
    """
    Feature: TedliumDataset
    Description: exercise num_samples, num_parallel_workers and repeat across the three releases
    Expectation: num_samples
                     set 1 2 4
                     get 1 2 4
                 num_parallel_workers
                     set 1 2 4(num_samples=4)
                     get 4 4 4
                 num repeat
                     set 3(num_samples=5)
                     get 15
    """
    def count_rows(dataset):
        # Drain one epoch and report how many rows were produced.
        return sum(1 for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True))

    # case1 test num_samples
    by_num_samples = [
        ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE1, num_samples=1),
        ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE2, num_samples=2),
        ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE3, RELEASE3, num_samples=4),
    ]
    rows_release1, rows_release2, rows_release3 = [count_rows(dataset) for dataset in by_num_samples]
    assert rows_release1 == 1
    assert rows_release2 == 2
    assert rows_release3 == 4

    # case2 test num_parallel_workers
    by_workers = [
        ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE1, num_samples=4, num_parallel_workers=1),
        ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE2, num_samples=4, num_parallel_workers=2),
        ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE3, RELEASE3, num_samples=4, num_parallel_workers=4),
    ]
    rows_one_worker, rows_two_workers, rows_four_workers = [count_rows(dataset) for dataset in by_workers]
    assert rows_one_worker == 4
    assert rows_two_workers == 4
    assert rows_four_workers == 4

    # case3 test repeat
    repeated = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE1, num_samples=5).repeat(3)
    assert count_rows(repeated) == 15
|
||||
|
||||
|
||||
def test_tedlium_content_check():
    """
    Feature: TedliumDataset
    Description: read the first sample of the release1 and release3 fixtures with shuffle disabled
    Expectation: every column of that sample has the expected dtype, shape and content
    """
    def verify_rows(dataset, expected_shape, expected_talk, expected_transcript):
        # Check every column of each produced row and return the number of rows seen.
        seen = 0
        for row in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
            waveform = row["waveform"]
            sample_rate = row["sample_rate"]
            transcript = row["transcript"]
            talk_id = row["talk_id"]
            speaker_id = row["speaker_id"]
            identifier = row["identifier"]
            assert waveform.dtype == np.float32
            assert waveform.shape == expected_shape
            assert sample_rate == 16000
            assert sample_rate.dtype == np.int32
            assert talk_id.item().decode("utf8") == expected_talk
            assert speaker_id.item().decode("utf8") == expected_talk
            assert transcript.item().decode("utf8") == expected_transcript
            assert identifier.item().decode("utf8") == "<o,f0,female>"
            seen += 1
        return seen

    release1_data = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE1, num_samples=1, shuffle=False)
    release3_data = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE3, RELEASE3, num_samples=1, shuffle=False)

    release1_rows = verify_rows(release1_data, (1, 480), "test1", "this is record 1 of test1.")
    release3_rows = verify_rows(release3_data, (1, 160), "test3", "this is record 1 of test3.")

    assert release1_rows == 1
    assert release3_rows == 1
|
||||
|
||||
|
||||
def test_tedlium_exceptions():
    """
    Feature: TedliumDataset
    Description: pass conflicting or out-of-range arguments, and a directory with no valid data
    Expectation: the matching RuntimeError / ValueError is raised
    """
    # Each entry: (expected exception type, message pattern, lazy constructor call).
    # The constructors are wrapped in lambdas so they only run inside pytest.raises.
    failing_constructions = (
        (RuntimeError, "sampler and shuffle cannot be specified at the same time",
         lambda: ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE1, shuffle=False, sampler=ds.PKSampler(3))),
        (RuntimeError, "sampler and sharding cannot be specified at the same time",
         lambda: ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE1, sampler=ds.PKSampler(3), num_shards=2,
                                   shard_id=0)),
        (RuntimeError, "num_shards is specified and currently requires shard_id as well",
         lambda: ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE2, num_shards=10)),
        (RuntimeError, "shard_id is specified but num_shards is not",
         lambda: ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE2, shard_id=0)),
        (ValueError, "Input shard_id is not within the required interval",
         lambda: ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE3, RELEASE3, num_shards=2, shard_id=-1)),
        (ValueError, "Input shard_id is not within the required interval",
         lambda: ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE3, RELEASE3, num_shards=2, shard_id=5)),
        (ValueError, "num_parallel_workers exceeds",
         lambda: ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE3, RELEASE3, shuffle=False, num_parallel_workers=0)),
        (ValueError, "num_parallel_workers exceeds",
         lambda: ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE3, RELEASE3, shuffle=False, num_parallel_workers=256)),
    )
    for expected_error, message, construct in failing_constructions:
        with pytest.raises(expected_error, match=message):
            construct()

    # Here the dataset is also iterated inside the pytest.raises block.
    with pytest.raises(RuntimeError, match="Invalid data, no valid data matching the dataset API TedliumDataset"):
        empty_dataset = ds.TedliumDataset(NO_SPH_DIR_TEDLIUM12, RELEASE1, "train")
        for _ in iter(empty_dataset):
            pass
|
||||
|
||||
|
||||
def test_tedlium_exception_file_path():
    """
    Feature: TedliumDataset
    Description: map a Python function that always raises onto each output column in turn
    Expectation: iterating raises a RuntimeError whose message contains the PyFunc map-failure text
    """
    def exception_func(item):
        raise Exception("Error occur!")

    expected_message = "map operation: [PyFunc] failed. The corresponding data files"

    # (dataset directory, release, column the failing function is mapped onto)
    cases = (
        (DATA_DIR_TEDLIUM_RELEASE12, RELEASE1, "waveform"),
        (DATA_DIR_TEDLIUM_RELEASE12, RELEASE1, "sample_rate"),
        (DATA_DIR_TEDLIUM_RELEASE12, RELEASE2, "transcript"),
        (DATA_DIR_TEDLIUM_RELEASE12, RELEASE2, "talk_id"),
        (DATA_DIR_TEDLIUM_RELEASE3, RELEASE3, "speaker_id"),
        (DATA_DIR_TEDLIUM_RELEASE3, RELEASE3, "identifier"),
    )
    for dataset_dir, release, column in cases:
        # pytest.raises fails the test if no RuntimeError is raised, replacing the `assert False` sentinel.
        with pytest.raises(RuntimeError) as exc_info:
            data = ds.TedliumDataset(dataset_dir, release)
            data = data.map(operations=exception_func, input_columns=[column], num_parallel_workers=1)
            for _ in data.create_dict_iterator():
                pass
        # Plain substring check: the expected text contains regex metacharacters ("[", "]", ".").
        assert expected_message in str(exc_info.value)
|
||||
|
||||
|
||||
def test_tedlium_extensions():
    """
    Feature: TedliumDataset
    Description: pass an unsupported value as the extensions argument
    Expectation: extensions
                     set invalid data
                     get throw error
    """
    try:
        bad_extension_data = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE1, "train", "invalid")
        for _ in bad_extension_data.create_dict_iterator(output_numpy=True):
            pass
    except RuntimeError as err:
        assert "is not supported." in str(err)
    else:
        # Reaching this point means no RuntimeError was raised, which is a test failure.
        assert False
|
||||
|
||||
|
||||
def test_tedlium_release():
    """
    Feature: TedliumDataset
    Description: pass invalid values as the release argument
    Expectation: release
                     set invalid data
                     get throw error
    """
    def construction_error(release):
        # Return the error text raised while constructing the dataset, or None if it succeeded.
        try:
            ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, release)
        except (ValueError, TypeError, RuntimeError) as err:
            return str(err)
        return None

    # test the release
    expectations = (
        ("release is not within the valid set of ['release1', 'release2', 'release3']", "invalid"),
        ("Argument release with value None is not of type [<class 'str'>]", None),
        ("Argument release with value ['list'] is not of type [<class 'str'>]", ["list"]),
    )
    for expected_text, bad_release in expectations:
        assert expected_text in construction_error(bad_release)
|
||||
|
||||
|
||||
def test_tedlium_sequential_sampler():
    """
    Feature: TedliumDataset
    Description: compare a SequentialSampler read with an unshuffled num_samples read
    Expectation: both yield the same waveforms, row by row, for the requested number of rows
    """
    expected_rows = 3
    sequential = ds.SequentialSampler(num_samples=expected_rows)
    via_sampler = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE2, sampler=sequential)
    via_num_samples = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE2, shuffle=False,
                                        num_samples=expected_rows)

    sampler_rows = via_sampler.create_dict_iterator(num_epochs=1, output_numpy=True)
    plain_rows = via_num_samples.create_dict_iterator(num_epochs=1, output_numpy=True)

    compared = 0
    for from_sampler, from_plain in zip(sampler_rows, plain_rows):
        np.testing.assert_equal(from_sampler["waveform"], from_plain["waveform"])
        compared += 1

    assert compared == expected_rows
|
||||
|
||||
|
||||
def test_tedlium_sampler_get_dataset_size():
    """
    Feature: TedliumDataset
    Description: test TedliumDataset with SequentialSampler and get_dataset_size
    Expectation: num_samples
                     set 5
                     get 5
    """
    five_rows = ds.SequentialSampler(start_index=0, num_samples=5)
    release3_data = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE3, RELEASE3, sampler=five_rows)

    # Query the reported size first, then count what iteration actually yields.
    reported_size = release3_data.get_dataset_size()
    iterated_rows = sum(1 for _ in release3_data.create_dict_iterator(num_epochs=1, output_numpy=True))

    assert reported_size == iterated_rows
    assert iterated_rows == 5
|
||||
|
||||
|
||||
def test_tedlium_usage():
    """
    Feature: TedliumDataset
    Description: read release1 and release2 with each valid usage and with invalid usage values
    Expectation: usage
                     set valid data      invalid data
                     get correct data    throw error
    """
    def rows_or_error(usage):
        # Return the combined row count of release1 and release2, or the error text on failure.
        try:
            # Both datasets are constructed before either one is iterated.
            datasets = [ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, release, usage=usage)
                        for release in (RELEASE1, RELEASE2)]
            total_rows = 0
            for dataset in datasets:
                for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
                    total_rows += 1
        except (ValueError, TypeError, RuntimeError) as err:
            return str(err)
        return total_rows

    # test the usage of TEDLIUM
    for usage, expected_rows in (("dev", 1 + 1),
                                 ("test", 2 + 2),
                                 ("train", 3 + 3),
                                 ("all", 1 + 1 + 2 + 2 + 3 + 3)):
        assert rows_or_error(usage) == expected_rows

    assert "usage is not within the valid set of ['train', 'test', 'dev', 'all']" in rows_or_error("invalid")
    assert "Argument usage with value ['list'] is not of type [<class 'str'>]" in rows_or_error(["list"])
|
||||
|
||||
|
||||
def test_tedlium_with_chained_sampler_get_dataset_size():
    """
    Feature: TedliumDataset
    Description: test TedliumDataset with RandomSampler chained with a SequentialSampler and get_dataset_size
    Expectation: num_samples
                     set 2
                     get 2
    """
    # The SequentialSampler is the parent; the RandomSampler is attached as its child.
    parent_sampler = ds.SequentialSampler(start_index=0, num_samples=2)
    random_child = ds.RandomSampler()
    parent_sampler.add_child(random_child)

    chained_data = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE1, sampler=parent_sampler)

    # Query the reported size first, then count what iteration actually yields.
    reported_size = chained_data.get_dataset_size()
    iterated_rows = sum(1 for _ in chained_data.create_dict_iterator(num_epochs=1, output_numpy=True))

    assert reported_size == iterated_rows
    assert iterated_rows == 2
|
||||
|
||||
|
||||
def test_tedlium_pipeline():
    """
    Feature: TedliumDataset
    Description: read one sample and run its waveform through a BandBiquad map operation
    Expectation: the mapped pipeline still yields exactly one row
    """
    # Original waveform
    pipeline = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE1, num_samples=1)
    biquad_filter = audio.BandBiquad(8000, 200.0)
    # Filtered waveform by bandbiquad
    pipeline = pipeline.map(input_columns=["waveform"], operations=biquad_filter, num_parallel_workers=2)

    produced_rows = sum(1 for _ in pipeline.create_dict_iterator(num_epochs=1, output_numpy=True))
    assert produced_rows == 1
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # When run as a script, execute every test of this module in definition order.
    for run_test in (
            test_tedlium_basic,
            test_tedlium_content_check,
            test_tedlium_exceptions,
            test_tedlium_exception_file_path,
            test_tedlium_extensions,
            test_tedlium_release,
            test_tedlium_sequential_sampler,
            test_tedlium_sampler_get_dataset_size,
            test_tedlium_usage,
            test_tedlium_with_chained_sampler_get_dataset_size,
            test_tedlium_pipeline,
    ):
        run_test()
|
Loading…
Reference in New Issue