!22736 [assistant][ops] Add new dataset operator LibriTTSDataset
Merge pull request !22736 from TR-nbu/LibriTTSDataset
This commit is contained in:
commit
85ac17a8c6
|
@ -99,6 +99,7 @@
|
|||
#include "minddata/dataset/engine/ir/datasetops/source/iwslt2016_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/iwslt2017_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/kmnist_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/libri_tts_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/lj_speech_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
|
||||
|
@ -1393,6 +1394,28 @@ KMnistDataset::KMnistDataset(const std::vector<char> &dataset_dir, const std::ve
|
|||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
LibriTTSDataset::LibriTTSDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||
const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
|
||||
auto ds = std::make_shared<LibriTTSNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
LibriTTSDataset::LibriTTSDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||
const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
|
||||
auto ds = std::make_shared<LibriTTSNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
LibriTTSDataset::LibriTTSDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||
const std::reference_wrapper<Sampler> &sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto sampler_obj = sampler.get().Parse();
|
||||
auto ds = std::make_shared<LibriTTSNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
LJSpeechDataset::LJSpeechDataset(const std::vector<char> &dataset_dir, const std::shared_ptr<Sampler> &sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
|
||||
|
|
|
@ -50,6 +50,7 @@
|
|||
#include "minddata/dataset/engine/ir/datasetops/source/iwslt2016_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/iwslt2017_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/kmnist_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/libri_tts_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/penn_treebank_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
|
||||
|
@ -415,6 +416,18 @@ PYBIND_REGISTER(KMnistNode, 2, ([](const py::module *m) {
|
|||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(LibriTTSNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<LibriTTSNode, DatasetNode, std::shared_ptr<LibriTTSNode>>(*m, "LibriTTSNode",
|
||||
"to create a LibriTTSNode")
|
||||
.def(
|
||||
py::init([](const std::string &dataset_dir, const std::string &usage, const py::handle &sampler) {
|
||||
std::shared_ptr<LibriTTSNode> libri_tts =
|
||||
std::make_shared<LibriTTSNode>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
|
||||
THROW_IF_ERROR(libri_tts->ValidateParams());
|
||||
return libri_tts;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(LJSpeechNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<LJSpeechNode, DatasetNode, std::shared_ptr<LJSpeechNode>>(*m, "LJSpeechNode",
|
||||
"to create a LJSpeechNode")
|
||||
|
|
|
@ -28,6 +28,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
|
|||
iwslt_op.cc
|
||||
io_block.cc
|
||||
kmnist_op.cc
|
||||
libri_tts_op.cc
|
||||
lj_speech_op.cc
|
||||
mappable_leaf_op.cc
|
||||
mnist_op.cc
|
||||
|
|
|
@ -0,0 +1,234 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/libri_tts_op.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
#include <set>
|
||||
|
||||
#include "minddata/dataset/audio/kernels/audio_utils.h"
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/core/tensor_shape.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "utils/file_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
const int32_t label_file_suffix_len = 10;
|
||||
const char label_file_suffix[] = ".trans.tsv";
|
||||
const char audio_file_suffix[] = ".wav";
|
||||
const std::vector<std::string> usage_list = {"dev-clean", "dev-other", "test-clean", "test-other",
|
||||
"train-clean-100", "train-clean-360", "train-other-500"};
|
||||
|
||||
LibriTTSOp::LibriTTSOp(const std::string &dataset_dir, const std::string &usage, int32_t num_workers,
|
||||
int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler)
|
||||
: MappableLeafOp(num_workers, queue_size, std::move(sampler)),
|
||||
dataset_dir_(dataset_dir),
|
||||
usage_(usage),
|
||||
data_schema_(std::move(data_schema)) {}
|
||||
|
||||
Status LibriTTSOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) {
|
||||
RETURN_UNEXPECTED_IF_NULL(trow);
|
||||
LibriTTSLabelTuple audio_tuple = audio_label_tuples_[row_id];
|
||||
const uint32_t rate = 24000;
|
||||
std::shared_ptr<Tensor> waveform, sample_rate, original_text, normalized_text, speaker_id, chapter_id, utterance_id;
|
||||
Path dir(real_path_);
|
||||
std::string file_name = audio_tuple.utterance_id + audio_file_suffix;
|
||||
Path full_dir = dir / audio_tuple.usage / std::to_string(audio_tuple.speaker_id) /
|
||||
std::to_string(audio_tuple.chapter_id) / file_name;
|
||||
RETURN_IF_NOT_OK(ReadAudio(full_dir.ToString(), &waveform));
|
||||
RETURN_IF_NOT_OK(Tensor::CreateScalar(rate, &sample_rate));
|
||||
RETURN_IF_NOT_OK(Tensor::CreateScalar(audio_tuple.original_text, &original_text));
|
||||
RETURN_IF_NOT_OK(Tensor::CreateScalar(audio_tuple.normalized_text, &normalized_text));
|
||||
RETURN_IF_NOT_OK(Tensor::CreateScalar(audio_tuple.speaker_id, &speaker_id));
|
||||
RETURN_IF_NOT_OK(Tensor::CreateScalar(audio_tuple.chapter_id, &chapter_id));
|
||||
RETURN_IF_NOT_OK(Tensor::CreateScalar(audio_tuple.utterance_id, &utterance_id));
|
||||
(*trow) = TensorRow(
|
||||
row_id, {std::move(waveform), std::move(sample_rate), std::move(original_text), std::move(normalized_text),
|
||||
std::move(speaker_id), std::move(chapter_id), std::move(utterance_id)});
|
||||
std::string label_path = audio_tuple.label_path;
|
||||
trow->setPath({full_dir.ToString(), full_dir.ToString(), label_path, label_path, label_path, label_path, label_path});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void LibriTTSOp::Print(std::ostream &out, bool show_all) const {
|
||||
if (!show_all) {
|
||||
ParallelOp::Print(out, show_all);
|
||||
out << "\n";
|
||||
} else {
|
||||
ParallelOp::Print(out, show_all);
|
||||
out << "\nNumber of rows: " << num_rows_ << "\nLibriTTS directory: " << dataset_dir_ << "\n\n";
|
||||
}
|
||||
}
|
||||
|
||||
Status LibriTTSOp::CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count) {
|
||||
RETURN_UNEXPECTED_IF_NULL(count);
|
||||
*count = 0;
|
||||
const int64_t num_samples = 0;
|
||||
const int64_t start_index = 0;
|
||||
auto sampler = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
|
||||
auto schema = std::make_unique<DataSchema>();
|
||||
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("waveform", DataType(DataType::DE_FLOAT32), TensorImpl::kCv, 1)));
|
||||
TensorShape scalar_rate = TensorShape::CreateScalar();
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(
|
||||
ColDescriptor("sample_rate", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar_rate)));
|
||||
TensorShape scalar_original_text = TensorShape::CreateScalar();
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(
|
||||
ColDescriptor("original_text", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar_original_text)));
|
||||
TensorShape scalar_normalized_text = TensorShape::CreateScalar();
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("normalized_text", DataType(DataType::DE_STRING),
|
||||
TensorImpl::kFlexible, 0, &scalar_normalized_text)));
|
||||
TensorShape scalar_speaker_id = TensorShape::CreateScalar();
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(
|
||||
ColDescriptor("speaker_id", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar_speaker_id)));
|
||||
TensorShape scalar_chapter_id = TensorShape::CreateScalar();
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(
|
||||
ColDescriptor("chapter_id", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar_chapter_id)));
|
||||
TensorShape scalar_utterance_id = TensorShape::CreateScalar();
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(
|
||||
ColDescriptor("utterance_id", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar_utterance_id)));
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
int32_t num_workers = cfg->num_parallel_workers();
|
||||
int32_t op_connect_size = cfg->op_connector_size();
|
||||
auto op =
|
||||
std::make_shared<LibriTTSOp>(dir, usage, num_workers, op_connect_size, std::move(schema), std::move(sampler));
|
||||
RETURN_IF_NOT_OK(op->PrepareData());
|
||||
*count = op->audio_label_tuples_.size();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status LibriTTSOp::ComputeColMap() {
|
||||
if (column_name_id_map_.empty()) {
|
||||
for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
|
||||
column_name_id_map_[data_schema_->Column(i).Name()] = i;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(WARNING) << "Column name map is already set!";
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status LibriTTSOp::ReadAudio(const std::string &audio_dir, std::shared_ptr<Tensor> *waveform) {
|
||||
RETURN_UNEXPECTED_IF_NULL(waveform);
|
||||
const int32_t kWavFileSampleRate = 24000;
|
||||
int32_t sample_rate = 0;
|
||||
std::vector<float> waveform_vec;
|
||||
RETURN_IF_NOT_OK(ReadWaveFile(audio_dir, &waveform_vec, &sample_rate));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
sample_rate == kWavFileSampleRate,
|
||||
"Invalid file, sampling rate of LibriTTS wav file must be 24000, file path: " + audio_dir);
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromVector(waveform_vec, waveform));
|
||||
RETURN_IF_NOT_OK((*waveform)->ExpandDim(0));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status LibriTTSOp::PrepareData() {
|
||||
auto realpath = FileUtils::GetRealPath(dataset_dir_.data());
|
||||
if (!realpath.has_value()) {
|
||||
MS_LOG(ERROR) << "Invalid file path, LibriTTS dataset dir: " << dataset_dir_ << " does not exist.";
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file path, LibriTTS dataset dir: " + dataset_dir_ + " does not exist.");
|
||||
}
|
||||
real_path_ = realpath.value();
|
||||
Path dir(real_path_);
|
||||
if (usage_ != "all") {
|
||||
Path full_dir = dir / usage_;
|
||||
cur_usage_ = usage_;
|
||||
RETURN_IF_NOT_OK(GetPaths(&full_dir));
|
||||
RETURN_IF_NOT_OK(GetLabels());
|
||||
} else {
|
||||
for (std::string usage_iter : usage_list) {
|
||||
cur_usage_ = usage_iter;
|
||||
Path full_dir = dir / cur_usage_;
|
||||
RETURN_IF_NOT_OK(GetPaths(&full_dir));
|
||||
RETURN_IF_NOT_OK(GetLabels());
|
||||
}
|
||||
}
|
||||
num_rows_ = audio_label_tuples_.size();
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0,
|
||||
"Invalid data, no valid data matching the dataset API LibriTTSDataset. "
|
||||
"Please check dataset API or file path: " +
|
||||
dataset_dir_ + ".");
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status LibriTTSOp::GetPaths(Path *dir) {
|
||||
RETURN_UNEXPECTED_IF_NULL(dir);
|
||||
auto iter = Path::DirIterator::OpenDirectory(dir);
|
||||
if (iter == nullptr) {
|
||||
MS_LOG(WARNING) << "Invalid file path, unable to open directory: " << dir->ToString() << ".";
|
||||
} else {
|
||||
while (iter->HasNext()) {
|
||||
Path sub_dir = iter->Next();
|
||||
if (sub_dir.IsDirectory()) {
|
||||
RETURN_IF_NOT_OK(GetPaths(&sub_dir));
|
||||
} else {
|
||||
Path file_path = sub_dir;
|
||||
std::string file_name = file_path.Basename();
|
||||
int32_t length = file_name.size();
|
||||
if (length > label_file_suffix_len && file_name.substr(length - label_file_suffix_len) == label_file_suffix) {
|
||||
label_files_.push_back(sub_dir.ToString());
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status LibriTTSOp::GetLabels() {
|
||||
std::string utterance_id_body = "";
|
||||
std::string original_text_body = "";
|
||||
std::string normalized_text_body = "";
|
||||
const uint32_t base = 10;
|
||||
const uint32_t ascii_zero = 48;
|
||||
const size_t underline_exact = 3;
|
||||
for (std::string label_file : label_files_) {
|
||||
std::ifstream label_reader(label_file);
|
||||
while (getline(label_reader, utterance_id_body, '\t')) {
|
||||
getline(label_reader, original_text_body, '\t');
|
||||
getline(label_reader, normalized_text_body, '\n');
|
||||
uint32_t speaker_id = 0;
|
||||
uint32_t chapter_id = 0;
|
||||
size_t underline_num = 0;
|
||||
size_t underline_inx[4] = {0};
|
||||
for (size_t i = 0; i < utterance_id_body.size() && underline_num <= underline_exact; i++) {
|
||||
if (utterance_id_body[i] == '_') {
|
||||
underline_inx[underline_num++] = i;
|
||||
}
|
||||
}
|
||||
if (underline_num != underline_exact) {
|
||||
label_reader.close();
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file, the file may not be a LibriTTS dataset file: " + label_file);
|
||||
}
|
||||
for (size_t i = 0; i < underline_inx[0]; i++) {
|
||||
speaker_id = speaker_id * base + utterance_id_body[i] - ascii_zero;
|
||||
}
|
||||
for (size_t i = underline_inx[0] + 1; i < underline_inx[1]; i++) {
|
||||
chapter_id = chapter_id * base + utterance_id_body[i] - ascii_zero;
|
||||
}
|
||||
audio_label_tuples_.push_back(
|
||||
{cur_usage_, utterance_id_body, original_text_body, normalized_text_body, speaker_id, chapter_id, label_file});
|
||||
}
|
||||
label_reader.close();
|
||||
}
|
||||
label_files_.clear();
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset.
|
||||
} // namespace mindspore.
|
|
@ -0,0 +1,120 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_LIBRI_TTS_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_LIBRI_TTS_OP_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/engine/data_schema.h"
|
||||
#include "minddata/dataset/engine/datasetops/parallel_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
|
||||
#include "minddata/dataset/util/path.h"
|
||||
#include "minddata/dataset/util/queue.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "minddata/dataset/util/wait_post.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
struct LibriTTSLabelTuple {
|
||||
std::string usage;
|
||||
std::string utterance_id;
|
||||
std::string original_text;
|
||||
std::string normalized_text;
|
||||
uint32_t speaker_id;
|
||||
uint32_t chapter_id;
|
||||
std::string label_path;
|
||||
};
|
||||
|
||||
class LibriTTSOp : public MappableLeafOp {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] dataset_dir Dir directory of LibriTTS.
|
||||
/// \param[in] usage usage of this dataset, can be "dev-clean", "dev-other", "test-clean", "test-other",
|
||||
/// "train-clean-100", "train-clean-360", "train-other-500", or "all".
|
||||
/// \param[in] num_workers Number of workers reading audios in parallel.
|
||||
/// \param[in] queue_size Connector queue size.
|
||||
/// \param[in] data_schema The schema of the LibriTTS dataset.
|
||||
/// \param[in] sampler Sampler tells LibriSpeechOp what to read.
|
||||
LibriTTSOp(const std::string &dataset_dir, const std::string &usage, int32_t num_workers, int32_t queue_size,
|
||||
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler);
|
||||
|
||||
/// \brief Destructor.
|
||||
~LibriTTSOp() = default;
|
||||
|
||||
/// \brief A print method typically used for debugging.
|
||||
/// \param[out] out Output stream.
|
||||
/// \param[in] show_all Whether to show all information.
|
||||
void Print(std::ostream &out, bool show_all) const override;
|
||||
|
||||
/// \brief Function to count the number of samples in the LibriTTS dataset.
|
||||
/// \param[in] dir Path to the LibriTTS directory.
|
||||
/// \param[in] usage Select the data set section.
|
||||
/// \param[out] count Output arg that will hold the minimum of the actual dataset size and numSamples.
|
||||
/// \return Status The status code returned.
|
||||
static Status CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count);
|
||||
|
||||
/// \brief Op name getter.
|
||||
/// \return Name of the current Op.
|
||||
std::string Name() const override { return "LibriTTSOp"; }
|
||||
|
||||
private:
|
||||
/// \brief Load a tensor row according to a pair.
|
||||
/// \param[in] row_id Id for this tensor row.
|
||||
/// \param[out] row Audio & label read into this tensor row.
|
||||
/// \return Status The status code returned.
|
||||
Status LoadTensorRow(row_id_type row_id, TensorRow *row) override;
|
||||
|
||||
/// \brief Read all paths in the directory.
|
||||
/// \param[in] dir File path to be traversed.
|
||||
/// \return Status The status code returned.
|
||||
Status GetPaths(Path *dir);
|
||||
|
||||
/// \brief Read all label files.
|
||||
/// \return Status The status code returned.
|
||||
Status GetLabels();
|
||||
|
||||
/// \brief Parse a single wav file.
|
||||
/// \param[in] audio_dir Audio file path.
|
||||
/// \param[out] waveform The output waveform tensor.
|
||||
/// \return Status The status code returned.
|
||||
Status ReadAudio(const std::string &audio_dir, std::shared_ptr<Tensor> *waveform);
|
||||
|
||||
/// \brief Prepare all data in the directory.
|
||||
/// \return Status The status code returned.
|
||||
Status PrepareData();
|
||||
|
||||
/// \brief Private function for computing the assignment of the column name map.
|
||||
/// \return Status The status code returned.
|
||||
Status ComputeColMap() override;
|
||||
|
||||
const std::string usage_;
|
||||
std::string cur_usage_;
|
||||
std::string real_path_;
|
||||
std::string dataset_dir_;
|
||||
std::unique_ptr<DataSchema> data_schema_;
|
||||
std::vector<LibriTTSLabelTuple> audio_label_tuples_;
|
||||
std::vector<std::string> label_files_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_LIBRI_TTS_OP_H_
|
|
@ -103,6 +103,7 @@ constexpr char kIMDBNode[] = "IMDBDataset";
|
|||
constexpr char kIWSLT2016Node[] = "IWSLT2016Dataset";
|
||||
constexpr char kIWSLT2017Node[] = "IWSLT2017Dataset";
|
||||
constexpr char kKMnistNode[] = "KMnistDataset";
|
||||
constexpr char kLibriTTSNode[] = "LibriTTSDataset";
|
||||
constexpr char kLJSpeechNode[] = "LJSpeechDataset";
|
||||
constexpr char kManifestNode[] = "ManifestDataset";
|
||||
constexpr char kMindDataNode[] = "MindDataDataset";
|
||||
|
|
|
@ -29,6 +29,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
|
|||
iwslt2016_node.cc
|
||||
iwslt2017_node.cc
|
||||
kmnist_node.cc
|
||||
libri_tts_node.cc
|
||||
lj_speech_node.cc
|
||||
manifest_node.cc
|
||||
minddata_node.cc
|
||||
|
|
|
@ -0,0 +1,121 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/libri_tts_node.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/libri_tts_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
LibriTTSNode::LibriTTSNode(const std::string &dataset_dir, const std::string &usage,
|
||||
std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache)
|
||||
: MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
|
||||
|
||||
void LibriTTSNode::Print(std::ostream &out) const { out << Name(); }
|
||||
|
||||
std::shared_ptr<DatasetNode> LibriTTSNode::Copy() {
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
|
||||
auto node = std::make_shared<LibriTTSNode>(dataset_dir_, usage_, sampler, cache_);
|
||||
return node;
|
||||
}
|
||||
|
||||
Status LibriTTSNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("LibriTTSDataset", dataset_dir_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("LibriTTSDataset", sampler_));
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("LibriTTSDataset", usage_,
|
||||
{"dev-clean", "dev-other", "test-clean", "test-other", "train-clean-100",
|
||||
"train-clean-360", "train-other-500", "all"}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status LibriTTSNode::GetShardId(int32_t *shard_id) {
|
||||
*shard_id = sampler_->ShardId();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status LibriTTSNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||
int64_t *dataset_size) {
|
||||
if (dataset_size_ > 0) {
|
||||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
int64_t num_rows, sample_size;
|
||||
RETURN_IF_NOT_OK(LibriTTSOp::CountTotalRows(dataset_dir_, usage_, &num_rows));
|
||||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
||||
sample_size = sampler_rt->CalculateNumSamples(num_rows);
|
||||
if (sample_size == -1) {
|
||||
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
|
||||
}
|
||||
*dataset_size = sample_size;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status LibriTTSNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
auto schema = std::make_unique<DataSchema>();
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("waveform", DataType(DataType::DE_FLOAT32), TensorImpl::kCv, 1)));
|
||||
TensorShape scalar_rate = TensorShape::CreateScalar();
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(
|
||||
ColDescriptor("sample_rate", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar_rate)));
|
||||
TensorShape scalar_original_text = TensorShape::CreateScalar();
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(
|
||||
ColDescriptor("original_text", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar_original_text)));
|
||||
TensorShape scalar_normalized_text = TensorShape::CreateScalar();
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("normalized_text", DataType(DataType::DE_STRING),
|
||||
TensorImpl::kFlexible, 0, &scalar_normalized_text)));
|
||||
TensorShape scalar_speaker_id = TensorShape::CreateScalar();
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(
|
||||
ColDescriptor("speaker_id", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar_speaker_id)));
|
||||
TensorShape scalar_chapter_id = TensorShape::CreateScalar();
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(
|
||||
ColDescriptor("chapter_id", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar_chapter_id)));
|
||||
TensorShape scalar_utterance_id = TensorShape::CreateScalar();
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(
|
||||
ColDescriptor("utterance_id", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar_utterance_id)));
|
||||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
||||
auto op = std::make_shared<LibriTTSOp>(dataset_dir_, usage_, num_workers_, connector_que_size_, std::move(schema),
|
||||
std::move(sampler_rt));
|
||||
op->SetTotalRepeats(GetTotalRepeats());
|
||||
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status LibriTTSNode::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args, sampler_args;
|
||||
RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
|
||||
args["sampler"] = sampler_args;
|
||||
args["num_parallel_workers"] = num_workers_;
|
||||
args["dataset_dir"] = dataset_dir_;
|
||||
args["usage"] = usage_;
|
||||
if (cache_ != nullptr) {
|
||||
nlohmann::json cache_args;
|
||||
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
|
||||
args["cache"] = cache_args;
|
||||
}
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,95 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_LIBRI_TTS_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_LIBRI_TTS_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class LibriTTSNode : public MappableSourceNode {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
LibriTTSNode(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler,
|
||||
std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor.
|
||||
~LibriTTSNode() = default;
|
||||
|
||||
/// \brief Node name getter.
|
||||
/// \return Name of the current node.
|
||||
std::string Name() const override { return kLibriTTSNode; }
|
||||
|
||||
/// \brief Print the description.
|
||||
/// \param out The output stream to write output to.
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object.
|
||||
/// \return A shared pointer to the new copy.
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class.
|
||||
/// \param node_ops A vector containing shared pointer to the Dataset Ops that this object will create.
|
||||
/// \return Status Status::OK() if build successfully.
|
||||
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
|
||||
|
||||
/// \brief Parameters validation.
|
||||
/// \return Status Status::OK() if all the parameters are valid.
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Get the shard id of node.
|
||||
/// \param[in] shard_id The shard ID within num_shards.
|
||||
/// \return Status Status::OK() if get shard id successfully.
|
||||
Status GetShardId(int32_t *shard_id) override;
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize.
|
||||
/// \param[in] size_getter Shared pointer to DatasetSizeGetter.
|
||||
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
|
||||
/// dataset size at the expense of accuracy.
|
||||
/// \param[out] dataset_size the size of the dataset.
|
||||
/// \return Status of the function.
|
||||
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||
int64_t *dataset_size) override;
|
||||
|
||||
/// \brief Getter functions.
|
||||
const std::string &DatasetDir() const { return dataset_dir_; }
|
||||
const std::string &usage() const { return usage_; }
|
||||
|
||||
/// \brief Get the arguments of node.
|
||||
/// \param[out] out_json JSON string of all attributes.
|
||||
/// \return Status of the function.
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// \brief Sampler getter.
|
||||
/// \return SamplerObj of the current node.
|
||||
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
|
||||
|
||||
/// \brief Sampler setter.
|
||||
/// \param[in] sampler Tells LibriTTSOp what to read.
|
||||
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string usage_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_LIBRI_TTS_NODE_H_
|
|
@ -3230,6 +3230,103 @@ inline std::shared_ptr<KMnistDataset> MS_API KMnist(const std::string &dataset_d
|
|||
return std::make_shared<KMnistDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
|
||||
}
|
||||
|
||||
/// \class LibriTTSDataset
|
||||
/// \brief A source dataset for reading and parsing LibriTTSDataset dataset.
|
||||
class MS_API LibriTTSDataset : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor of LibriTTSDataset.
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage Part of dataset of LibriTTS, can be "dev-clean", "dev-other", "test-clean",
|
||||
/// "test-other", "train-clean-100", "train-clean-360", "train-other-500" or "all".
|
||||
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
LibriTTSDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||
const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Constructor of LibriTTSDataset.
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage Part of dataset of LibriTTS, can be "dev-clean", "dev-other", "test-clean",
|
||||
/// "test-other", "train-clean-100", "train-clean-360", "train-other-500" or "all".
|
||||
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
LibriTTSDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const Sampler *sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Constructor of LibriTTSDataset.
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage Part of dataset of LibriTTS, can be "dev-clean", "dev-other", "test-clean",
|
||||
/// "test-other", "train-clean-100", "train-clean-360", "train-other-500" or "all".
|
||||
/// \param[in] sampler Sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
LibriTTSDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||
const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Destructor of LibriTTSDataset.
|
||||
~LibriTTSDataset() = default;
|
||||
};
|
||||
|
||||
/// \brief Function to create a LibriTTSDataset.
|
||||
/// \note The generated dataset has seven columns ['waveform', 'sample_rate', 'original_text', 'normalized_text',
|
||||
/// 'speaker_id', 'chapter_id', 'utterance_id'].
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage Part of dataset of LibriTTS, can be "dev-clean", "dev-other", "test-clean", "test-other",
|
||||
/// "train-clean-100", "train-clean-360", "train-other-500", or "all" (default = "all").
|
||||
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
|
||||
/// given, a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()).
|
||||
/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
|
||||
/// \return Shared pointer to the LibriTTSDataset.
|
||||
/// \par Example
|
||||
/// \code
|
||||
/// /* Define dataset path and LibriTTS object */
|
||||
/// std::string folder_path = "/path/to/libri_tts_dataset_directory";
|
||||
/// std::shared_ptr<Dataset> ds = LibriTTS(folder_path);
|
||||
///
|
||||
/// /* Create iterator to read dataset */
|
||||
/// std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
/// std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
/// iter->GetNextRow(&row);
|
||||
///
|
||||
/// /* Note: In LibriTTS dataset, each data dictionary has seven columns ["waveform", "sample_rate",
|
||||
/// "original_text", "normalized_text", "speaker_id", "chapter_id", "utterance_id"].*/
|
||||
/// auto waveform = row["waveform"];
|
||||
/// \endcode
|
||||
inline std::shared_ptr<LibriTTSDataset> MS_API
|
||||
LibriTTS(const std::string &dataset_dir, const std::string &usage = "all",
|
||||
const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<LibriTTSDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
|
||||
}
|
||||
|
||||
/// \brief Function to create a LibriTTSDataset.
|
||||
/// \note The generated dataset has seven columns ['waveform', 'sample_rate', 'original_text', 'normalized_text',
|
||||
/// 'speaker_id', 'chapter_id', 'utterance_id'].
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage Part of dataset of LibriTTS, can be "dev-clean", "dev-other", "test-clean", "test-other",
|
||||
/// "train-clean-100", "train-clean-360", "train-other-500", or "all".
|
||||
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
|
||||
/// \return Shared pointer to the LibriTTSDataset.
|
||||
inline std::shared_ptr<LibriTTSDataset> MS_API LibriTTS(const std::string &dataset_dir, const std::string &usage,
|
||||
const Sampler *sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<LibriTTSDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
|
||||
}
|
||||
|
||||
/// \brief Function to create a LibriTTSDataset.
|
||||
/// \note The generated dataset has seven columns ['waveform', 'sample_rate', 'original_text', 'normalized_text',
|
||||
/// 'speaker_id', 'chapter_id', 'utterance_id'].
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] usage Part of dataset of LibriTTS, can be "dev-clean", "dev-other", "test-clean", "test-other",
|
||||
/// "train-clean-100", "train-clean-360", "train-other-500", or "all".
|
||||
/// \param[in] sampler Sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
|
||||
/// \return Shared pointer to the LibriTTSDataset.
|
||||
inline std::shared_ptr<LibriTTSDataset> MS_API LibriTTS(const std::string &dataset_dir, const std::string &usage,
|
||||
const std::reference_wrapper<Sampler> &sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<LibriTTSDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
|
||||
}
|
||||
|
||||
/// \class LJSpeechDataset
|
||||
/// \brief A source dataset for reading and parsing LJSpeech dataset.
|
||||
class MS_API LJSpeechDataset : public Dataset {
|
||||
|
|
|
@ -50,6 +50,7 @@ class MS_API Sampler : std::enable_shared_from_this<Sampler> {
|
|||
friend class ImageFolderDataset;
|
||||
friend class IMDBDataset;
|
||||
friend class KMnistDataset;
|
||||
friend class LibriTTSDataset;
|
||||
friend class LJSpeechDataset;
|
||||
friend class ManifestDataset;
|
||||
friend class MindDataDataset;
|
||||
|
|
|
@ -84,6 +84,7 @@ __all__ = ["Caltech101Dataset", # Vision
|
|||
"YelpReviewDataset", # Text
|
||||
"CMUArcticDataset", # Audio
|
||||
"GTZANDataset", # Audio
|
||||
"LibriTTSDataset", # Audio
|
||||
"LJSpeechDataset", # Audio
|
||||
"SpeechCommandsDataset", # Audio
|
||||
"TedliumDataset", # Audio
|
||||
|
|
|
@ -26,8 +26,8 @@ After declaring the dataset object, you can further apply dataset operations
|
|||
import mindspore._c_dataengine as cde
|
||||
|
||||
from .datasets import AudioBaseDataset, MappableDataset
|
||||
from .validators import check_cmu_arctic_dataset, check_gtzan_dataset, check_lj_speech_dataset, check_speech_commands_dataset, \
|
||||
check_tedlium_dataset, check_yes_no_dataset
|
||||
from .validators import check_cmu_arctic_dataset, check_gtzan_dataset, check_libri_tts_dataset, check_lj_speech_dataset, \
|
||||
check_speech_commands_dataset, check_tedlium_dataset, check_yes_no_dataset
|
||||
|
||||
from ..core.validator_helpers import replace_none
|
||||
|
||||
|
@ -299,6 +299,156 @@ class GTZANDataset(MappableDataset, AudioBaseDataset):
|
|||
return cde.GTZANNode(self.dataset_dir, self.usage, self.sampler)
|
||||
|
||||
|
||||
class LibriTTSDataset(MappableDataset, AudioBaseDataset):
|
||||
"""
|
||||
A source dataset that reads and parses the LibriTTS dataset.
|
||||
|
||||
The generated dataset has seven columns :py:obj:`['waveform', 'sample_rate', 'original_text', 'normalized_text',
|
||||
'speaker_id', 'chapter_id', 'utterance_id']`.
|
||||
The tensor of column :py:obj:`waveform` is of the float32 type.
|
||||
The tensor of column :py:obj:`sample_rate` is of a scalar of uint32 type.
|
||||
The tensor of column :py:obj:`original_text` is of a scalar of string type.
|
||||
The tensor of column :py:obj:`normalized_text` is of a scalar of string type.
|
||||
The tensor of column :py:obj:`speaker_id` is of a scalar of uint32 type.
|
||||
The tensor of column :py:obj:`chapter_id` is of a scalar of uint32 type.
|
||||
The tensor of column :py:obj:`utterance_id` is of a scalar of string type.
|
||||
|
||||
Args:
|
||||
dataset_dir (str): Path to the root directory that contains the dataset.
|
||||
usage (str, optional): Part of this dataset, can be ""dev-clean", "dev-other", "test-clean", "test-other",
|
||||
"train-clean-100", "train-clean-360", "train-other-500", or "all" (default=None, equal "all").
|
||||
num_samples (int, optional): The number of images to be included in the dataset
|
||||
(default=None, will read all audio).
|
||||
num_parallel_workers (int, optional): Number of workers to read the data
|
||||
(default=None, will use value set in the config).
|
||||
shuffle (bool, optional): Whether or not to perform shuffle on the dataset
|
||||
(default=None, expected order behavior shown in the table).
|
||||
sampler (Sampler, optional): Object used to choose samples from the
|
||||
dataset (default=None, expected order behavior shown in the table).
|
||||
num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
|
||||
When this argument is specified, `num_samples` reflects the max sample number of per shard.
|
||||
shard_id (int, optional): The shard ID within `num_shards` (default=None). This
|
||||
argument can only be specified when `num_shards` is also specified.
|
||||
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing
|
||||
(default=None, which means no cache is used).
|
||||
|
||||
Raises:
|
||||
RuntimeError: If source raises an exception during execution.
|
||||
RuntimeError: If dataset_dir does not contain data files.
|
||||
RuntimeError: If num_parallel_workers exceeds the max thread numbers.
|
||||
RuntimeError: If sampler and shuffle are specified at the same time.
|
||||
RuntimeError: If sampler and sharding are specified at the same time.
|
||||
RuntimeError: If num_shards is specified but shard_id is None.
|
||||
RuntimeError: If shard_id is specified but num_shards is None.
|
||||
ValueError: If shard_id is invalid (< 0 or >= num_shards).
|
||||
|
||||
Note:
|
||||
- LibriTTS dataset doesn't support PKSampler.
|
||||
- This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive.
|
||||
The table below shows what input arguments are allowed and their expected behavior.
|
||||
|
||||
.. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
|
||||
:widths: 25 25 50
|
||||
:header-rows: 1
|
||||
|
||||
* - Parameter `sampler`
|
||||
- Parameter `shuffle`
|
||||
- Expected Order Behavior
|
||||
* - None
|
||||
- None
|
||||
- random order
|
||||
* - None
|
||||
- True
|
||||
- random order
|
||||
* - None
|
||||
- False
|
||||
- sequential order
|
||||
* - Sampler object
|
||||
- None
|
||||
- order defined by sampler
|
||||
* - Sampler object
|
||||
- True
|
||||
- not allowed
|
||||
* - Sampler object
|
||||
- False
|
||||
- not allowed
|
||||
|
||||
Examples:
|
||||
>>> libri_tts_dataset_dir = "/path/to/libri_tts_dataset_directory"
|
||||
>>>
|
||||
>>> # 1) Read 500 samples (audio files) in libri_tts_dataset_directory
|
||||
>>> dataset = ds.LibriTTSDataset(libri_tts_dataset_dir, usage="train-clean-100", num_samples=500)
|
||||
>>>
|
||||
>>> # 2) Read all samples (audio files) in libri_tts_dataset_directory
|
||||
>>> dataset = ds.LibriTTSDataset(libri_tts_dataset_dir)
|
||||
|
||||
About LibriTTS dataset:
|
||||
|
||||
LibriTTS is a multi-speaker English corpus of approximately 585 hours of read English speech at 24kHz
|
||||
sampling rate, prepared by Heiga Zen with the assistance of Google Speech and Google Brain team members.
|
||||
The LibriTTS corpus is designed for TTS research. It is derived from the original materials (mp3 audio
|
||||
files from LibriVox and text files from Project Gutenberg) of the LibriSpeech corpus.
|
||||
|
||||
You can construct the following directory structure from LibriTTS dataset and read by MindSpore's API.
|
||||
|
||||
.. code-block::
|
||||
|
||||
.
|
||||
└── libri_tts_dataset_directory
|
||||
├── dev-clean
|
||||
│ ├── 116
|
||||
│ │ ├── 288045
|
||||
| | | ├── 116_288045.trans.tsv
|
||||
│ │ │ ├── 116_288045_000003_000000.wav
|
||||
│ │ │ └──...
|
||||
│ │ ├── 288046
|
||||
| | | ├── 116_288046.trans.tsv
|
||||
| | | ├── 116_288046_000003_000000.wav
|
||||
│ | | └── ...
|
||||
| | └── ...
|
||||
│ ├── 1255
|
||||
│ │ ├── 138279
|
||||
| | | ├── 1255_138279.trans.tsv
|
||||
│ │ │ ├── 1255_138279_000001_000000.wav
|
||||
│ │ │ └── ...
|
||||
│ │ ├── 74899
|
||||
| | | ├── 1255_74899.trans.tsv
|
||||
| | | ├── 1255_74899_000001_000000.wav
|
||||
│ | | └── ...
|
||||
| | └── ...
|
||||
| └── ...
|
||||
└── ...
|
||||
|
||||
Citation:
|
||||
|
||||
.. code-block::
|
||||
|
||||
@article{lecun2010mnist,
|
||||
title = {LIBRITTS handwritten digit database},
|
||||
author = {zpw, NBU},
|
||||
journal = {ATT Labs [Online]},
|
||||
volume = {2},
|
||||
year = {2010},
|
||||
howpublished = {http://www.openslr.org/resources/60/},
|
||||
description = {The LibriSpeech ASR corpus (http://www.openslr.org/12/) [1] has been used in
|
||||
various research projects. However, as it was originally designed for ASR research,
|
||||
there are some undesired properties when using for TTS research}
|
||||
}
|
||||
"""
|
||||
|
||||
@check_libri_tts_dataset
|
||||
def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None,
|
||||
sampler=None, num_shards=None, shard_id=None, cache=None):
|
||||
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
|
||||
shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
|
||||
|
||||
self.dataset_dir = dataset_dir
|
||||
self.usage = replace_none(usage, "all")
|
||||
|
||||
def parse(self, children=None):
|
||||
return cde.LibriTTSNode(self.dataset_dir, self.usage, self.sampler)
|
||||
|
||||
|
||||
class LJSpeechDataset(MappableDataset, AudioBaseDataset):
|
||||
"""
|
||||
A source dataset that reads and parses LJSpeech dataset.
|
||||
|
|
|
@ -729,6 +729,35 @@ def check_celebadataset(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_libri_tts_dataset(method):
|
||||
"""A wrapper that wraps a parameter checker around the original Dataset(LibriTTSDataset)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
_, param_dict = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
|
||||
nreq_param_bool = ['shuffle']
|
||||
|
||||
dataset_dir = param_dict.get('dataset_dir')
|
||||
check_dir(dataset_dir)
|
||||
|
||||
usage = param_dict.get('usage')
|
||||
if usage is not None:
|
||||
check_valid_str(usage, ["dev-clean", "dev-other", "test-clean", "test-other", "train-clean-100",
|
||||
"train-clean-360", "train-other-500", "all"], "usage")
|
||||
validate_dataset_param_value(nreq_param_int, param_dict, int)
|
||||
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
|
||||
|
||||
check_sampler_shuffle_shard_options(param_dict)
|
||||
cache = param_dict.get('cache')
|
||||
check_cache_option(cache)
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_lj_speech_dataset(method):
|
||||
"""A wrapper that wraps a parameter checker around the original Dataset(LJSpeechDataset)."""
|
||||
|
||||
|
|
|
@ -38,6 +38,7 @@ SET(DE_UT_SRCS
|
|||
c_api_dataset_iterator_test.cc
|
||||
c_api_dataset_iwslt_test.cc
|
||||
c_api_dataset_kmnist_test.cc
|
||||
c_api_dataset_libri_tts.cc
|
||||
c_api_dataset_lj_speech_test.cc
|
||||
c_api_dataset_manifest_test.cc
|
||||
c_api_dataset_minddata_test.cc
|
||||
|
|
|
@ -0,0 +1,311 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/common.h"
|
||||
#include "minddata/dataset/include/dataset/datasets.h"
|
||||
#include "include/dataset/transforms.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::dataset::Tensor;
|
||||
|
||||
class MindDataTestPipeline : public UT::DatasetOpTesting {
|
||||
protected:
|
||||
};
|
||||
|
||||
/// Feature: LibriTTSDataset
|
||||
/// Description: test LibriTTS
|
||||
/// Expectation: get correct LibriTTS dataset
|
||||
TEST_F(MindDataTestPipeline, TestLibriTTSBasic) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLibriTTSBasic.";
|
||||
|
||||
std::string folder_path = datasets_root_path_ + "/testLibriTTSData";
|
||||
std::shared_ptr<Dataset> ds = LibriTTS(folder_path);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset.
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row.
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
uint64_t i = 0;
|
||||
|
||||
while (row.size() != 0) {
|
||||
auto waveform = row["waveform"];
|
||||
auto sample_rate = row["sample_rate"];
|
||||
auto original_text = row["original_text"];
|
||||
auto normalized_text = row["normalized_text"];
|
||||
auto speaker_id = row["speaker_id"];
|
||||
auto chapter_id = row["chapter_id"];
|
||||
auto utterance_id = row["utterance_id"];
|
||||
i++;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
EXPECT_EQ(i, 3);
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: LibriTTSDataset
|
||||
/// Description: test LibriTTS with Pipeline
|
||||
/// Expectation: get correct LibriTTS dataset
|
||||
TEST_F(MindDataTestPipeline, TestLibriTTSBasicWithPipeline) {
|
||||
MS_LOG(INFO) << "Doing DataSetOpBatchTest-TestLibriTTSBasicWithPipeline.";
|
||||
|
||||
// Create a LibriTTSDataset Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testLibriTTSData";
|
||||
std::shared_ptr<Dataset> ds = LibriTTS(folder_path, "train-clean-100", std::make_shared<SequentialSampler>(0, 2));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
auto op = transforms::PadEnd({1, 500000});
|
||||
std::vector<std::string> input_columns = {"waveform"};
|
||||
std::vector<std::string> output_columns = {"waveform"};
|
||||
std::vector<std::string> project_columns = {"sample_rate", "original_text", "normalized_text", "speaker_id",
|
||||
"chapter_id", "utterance_id", "waveform"};
|
||||
ds = ds->Map({op}, input_columns, output_columns, project_columns);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
ds = ds->Repeat(5);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
ds = ds->Batch(2);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
// Create an iterator over the result of the above dataset.
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
// Iterate the dataset and get each row.
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
iter->GetNextRow(&row);
|
||||
std::vector<std::string> expected_original_text = {"good morning", "good afternoon"};
|
||||
std::vector<std::string> expected_normalized_text = {"Good morning", "Good afternoon"};
|
||||
std::vector<uint32_t> expected_speaker_id = {2506, 2506};
|
||||
std::vector<uint32_t> expected_chapter_id = {11267, 11267};
|
||||
std::vector<std::string> expected_utterance_id = {"2506_11267_000001_000000", "2506_11267_000002_000000"};
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto waveform = row["waveform"];
|
||||
auto original_text = row["original_text"];
|
||||
auto normalized_text = row["normalized_text"];
|
||||
auto sample_rate = row["sample_rate"];
|
||||
auto speaker_id = row["speaker_id"];
|
||||
auto chapter_id = row["chapter_id"];
|
||||
auto utterance_id = row["utterance_id"];
|
||||
|
||||
std::shared_ptr<Tensor> de_original_text;
|
||||
ASSERT_OK(Tensor::CreateFromVector(expected_original_text, &de_original_text));
|
||||
mindspore::MSTensor fix_original_text =
|
||||
mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(de_original_text));
|
||||
EXPECT_MSTENSOR_EQ(original_text, fix_original_text);
|
||||
|
||||
std::shared_ptr<Tensor> de_normalized_text;
|
||||
ASSERT_OK(Tensor::CreateFromVector(expected_normalized_text, &de_normalized_text));
|
||||
mindspore::MSTensor fix_normalized_text =
|
||||
mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(de_normalized_text));
|
||||
EXPECT_MSTENSOR_EQ(normalized_text, fix_normalized_text);
|
||||
|
||||
std::shared_ptr<Tensor> de_expected_speaker_id;
|
||||
ASSERT_OK(Tensor::CreateFromVector(expected_speaker_id, &de_expected_speaker_id));
|
||||
mindspore::MSTensor fix_expected_speaker_id =
|
||||
mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(de_expected_speaker_id));
|
||||
EXPECT_MSTENSOR_EQ(speaker_id, fix_expected_speaker_id);
|
||||
|
||||
std::shared_ptr<Tensor> de_expected_chapter_id;
|
||||
ASSERT_OK(Tensor::CreateFromVector(expected_chapter_id, &de_expected_chapter_id));
|
||||
mindspore::MSTensor fix_expected_chapter_id =
|
||||
mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(de_expected_chapter_id));
|
||||
EXPECT_MSTENSOR_EQ(chapter_id, fix_expected_chapter_id);
|
||||
|
||||
std::shared_ptr<Tensor> de_expected_utterance_id;
|
||||
ASSERT_OK(Tensor::CreateFromVector(expected_utterance_id, &de_expected_utterance_id));
|
||||
mindspore::MSTensor fix_expected_utterance_id =
|
||||
mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(de_expected_utterance_id));
|
||||
EXPECT_MSTENSOR_EQ(utterance_id, fix_expected_utterance_id);
|
||||
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 5);
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: LibriTTSDataset
|
||||
/// Description: test LibriTTS with invalid directory
|
||||
/// Expectation: get correct LibriTTS dataset
|
||||
TEST_F(MindDataTestPipeline, TestLibriTTSError) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLibriTTSError.";
|
||||
|
||||
// Create a LibriTTS Dataset with non-existing dataset dir
|
||||
std::shared_ptr<Dataset> ds0 = LibriTTS("NotExistFile");
|
||||
EXPECT_NE(ds0, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter0 = ds0->CreateIterator();
|
||||
// Expect failure: invalid LibriTTS input
|
||||
EXPECT_EQ(iter0, nullptr);
|
||||
|
||||
// Create a LibriTTS Dataset with invalid string of dataset dir
|
||||
std::shared_ptr<Dataset> ds1 = LibriTTS(":*?\"<>|`&;'");
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
|
||||
// Expect failure: invalid LibriTTS input
|
||||
EXPECT_EQ(iter1, nullptr);
|
||||
}
|
||||
|
||||
/// Feature: LibriTTSDataset
|
||||
/// Description: test LibriTTS with Getters
|
||||
/// Expectation: dataset is null
|
||||
TEST_F(MindDataTestPipeline, TestLibriTTSGetters) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLibriTTSGetters.";
|
||||
|
||||
std::string folder_path = datasets_root_path_ + "/testLibriTTSData";
|
||||
// Create a LibriTTS Dataset.
|
||||
std::shared_ptr<Dataset> ds1 = LibriTTS(folder_path);
|
||||
std::shared_ptr<Dataset> ds2 = LibriTTS(folder_path, "train-clean-100");
|
||||
|
||||
std::vector<std::string> column_names = {"waveform", "sample_rate", "original_text", "normalized_text",
|
||||
"speaker_id", "chapter_id", "utterance_id"};
|
||||
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
EXPECT_EQ(ds1->GetDatasetSize(), 3);
|
||||
EXPECT_EQ(ds1->GetColumnNames(), column_names);
|
||||
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
EXPECT_EQ(ds2->GetDatasetSize(), 3);
|
||||
EXPECT_EQ(ds2->GetColumnNames(), column_names);
|
||||
}
|
||||
|
||||
/// Feature: LibriTTSDataset
|
||||
/// Description: test LibriTTS dataset with invalid type
|
||||
/// Expectation: dataset is null
|
||||
TEST_F(MindDataTestPipeline, TestLibriTTSWithInvalidUsageError) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLibriTTSWithInvalidUsageError.";
|
||||
|
||||
std::string folder_path = datasets_root_path_ + "/testLibriTTSData";
|
||||
// Create a LibriTTS Dataset.
|
||||
std::shared_ptr<Dataset> ds1 = LibriTTS(folder_path, "----");
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
|
||||
// Expect failure: invalid LibriTTS input, sampler cannot be nullptr
|
||||
EXPECT_EQ(iter1, nullptr);
|
||||
|
||||
std::shared_ptr<Dataset> ds2 = LibriTTS(folder_path, "csacs");
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter2 = ds2->CreateIterator();
|
||||
// Expect failure: invalid LibriTTS input, sampler cannot be nullptr
|
||||
EXPECT_EQ(iter2, nullptr);
|
||||
}
|
||||
|
||||
/// Feature: LibriTTSDataset
|
||||
/// Description: test LibriTTS dataset with null sampler
|
||||
/// Expectation: dataset is null
|
||||
TEST_F(MindDataTestPipeline, TestLibriTTSWithNullSamplerError) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLibriTTSWithNullSamplerError.";
|
||||
|
||||
std::string folder_path = datasets_root_path_ + "/testLibriTTSData";
|
||||
// Create a LibriTTS Dataset.
|
||||
std::shared_ptr<Dataset> ds = LibriTTS(folder_path, "all", nullptr);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
// Expect failure: invalid LibriTTS input, sampler cannot be nullptr
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
||||
/// Feature: LibriTTSDataset
|
||||
/// Description: test LibriTTS with sequential sampler
|
||||
/// Expectation: get correct LibriTTS dataset
|
||||
TEST_F(MindDataTestPipeline, TestLibriTTSSequentialSamplers) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLibriTTSSequentialSamplers.";
|
||||
|
||||
std::string folder_path = datasets_root_path_ + "/testLibriTTSData";
|
||||
std::shared_ptr<Dataset> ds = LibriTTS(folder_path, "all", std::make_shared<SequentialSampler>(0, 2));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset.
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row.
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
std::string_view original_text_idx, normalized_text_idx, utterance_id_idx;
|
||||
uint32_t speaker_idx_id = 0, chapter_idx_id = 0;
|
||||
std::vector<std::string> expected_original_text = {"good morning", "good afternoon"};
|
||||
std::vector<std::string> expected_normalized_text = {"Good morning", "Good afternoon"};
|
||||
std::vector<uint32_t> expected_speaker_id = {2506, 2506};
|
||||
std::vector<uint32_t> expected_chapter_id = {11267, 11267};
|
||||
std::vector<std::string> expected_utterance_id = {"2506_11267_000001_000000", "2506_11267_000002_000000"};
|
||||
uint32_t rate = 0;
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
auto waveform = row["waveform"];
|
||||
auto sample_rate = row["sample_rate"];
|
||||
auto original_text = row["original_text"];
|
||||
auto normalized_text = row["normalized_text"];
|
||||
auto speaker_id = row["speaker_id"];
|
||||
auto chapter_id = row["chapter_id"];
|
||||
auto utterance_id = row["utterance_id"];
|
||||
|
||||
MS_LOG(INFO) << "Tensor waveform shape: " << waveform.Shape();
|
||||
|
||||
std::shared_ptr<Tensor> trate;
|
||||
ASSERT_OK(Tensor::CreateFromMSTensor(sample_rate, &trate));
|
||||
ASSERT_OK(trate->GetItemAt<uint32_t>(&rate, {}));
|
||||
EXPECT_EQ(rate, 24000);
|
||||
|
||||
std::shared_ptr<Tensor> de_original_text;
|
||||
ASSERT_OK(Tensor::CreateFromMSTensor(original_text, &de_original_text));
|
||||
ASSERT_OK(de_original_text->GetItemAt(&original_text_idx, {}));
|
||||
std::string s_original_text(original_text_idx);
|
||||
EXPECT_STREQ(s_original_text.c_str(), expected_original_text[i].c_str());
|
||||
|
||||
std::shared_ptr<Tensor> de_normalized_text;
|
||||
ASSERT_OK(Tensor::CreateFromMSTensor(normalized_text, &de_normalized_text));
|
||||
ASSERT_OK(de_normalized_text->GetItemAt(&normalized_text_idx, {}));
|
||||
std::string s_normalized_text(normalized_text_idx);
|
||||
EXPECT_STREQ(s_normalized_text.c_str(), expected_normalized_text[i].c_str());
|
||||
|
||||
std::shared_ptr<Tensor> de_speaker_id;
|
||||
ASSERT_OK(Tensor::CreateFromMSTensor(speaker_id, &de_speaker_id));
|
||||
ASSERT_OK(de_speaker_id->GetItemAt<uint32_t>(&speaker_idx_id, {}));
|
||||
EXPECT_EQ(speaker_idx_id, expected_speaker_id[i]);
|
||||
|
||||
std::shared_ptr<Tensor> de_chapter_id;
|
||||
ASSERT_OK(Tensor::CreateFromMSTensor(chapter_id, &de_chapter_id));
|
||||
ASSERT_OK(de_chapter_id->GetItemAt<uint32_t>(&chapter_idx_id, {}));
|
||||
EXPECT_EQ(chapter_idx_id, expected_chapter_id[i]);
|
||||
|
||||
std::shared_ptr<Tensor> de_utterance_id;
|
||||
ASSERT_OK(Tensor::CreateFromMSTensor(utterance_id, &de_utterance_id));
|
||||
ASSERT_OK(de_utterance_id->GetItemAt(&utterance_id_idx, {}));
|
||||
std::string s_utterance_id(utterance_id_idx);
|
||||
EXPECT_STREQ(s_utterance_id.c_str(), expected_utterance_id[i].c_str());
|
||||
|
||||
i++;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 2);
|
||||
|
||||
iter->Stop();
|
||||
}
|
|
@ -0,0 +1,3 @@
|
|||
2506_11267_000001_000000 good morning Good morning
|
||||
2506_11267_000002_000000 good afternoon Good afternoon
|
||||
2506_11267_000003_000001 good evening Good evening
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,235 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License foNtest_resr the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Test LibriTTS dataset operators
|
||||
"""
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import log as logger
|
||||
|
||||
DATA_DIR = "../data/dataset/testLibriTTSData"
|
||||
|
||||
|
||||
def test_libri_tts_basic():
|
||||
"""
|
||||
Feature: LibriTTSDataset
|
||||
Description: test basic usage of LibriTTS
|
||||
Expectation: the dataset is as expected
|
||||
"""
|
||||
logger.info("Test LibriTTSDataset Op")
|
||||
|
||||
# case 1: test loading fault dataset.
|
||||
data1 = ds.LibriTTSDataset(DATA_DIR)
|
||||
num_iter1 = 0
|
||||
for _ in data1.create_dict_iterator(output_numpy=True, num_epochs=1):
|
||||
num_iter1 += 1
|
||||
assert num_iter1 == 3
|
||||
|
||||
# case 2: test num_samples.
|
||||
data2 = ds.LibriTTSDataset(DATA_DIR, num_samples=1)
|
||||
num_iter2 = 0
|
||||
for _ in data2.create_dict_iterator(output_numpy=True, num_epochs=1):
|
||||
num_iter2 += 1
|
||||
assert num_iter2 == 1
|
||||
|
||||
# case 3: test repeat.
|
||||
data3 = ds.LibriTTSDataset(DATA_DIR, usage="all", num_samples=3)
|
||||
data3 = data3.repeat(3)
|
||||
num_iter3 = 0
|
||||
for _ in data3.create_dict_iterator(output_numpy=True, num_epochs=1):
|
||||
num_iter3 += 1
|
||||
assert num_iter3 == 9
|
||||
|
||||
# case 4: test batch with drop_remainder=False.
|
||||
data4 = ds.LibriTTSDataset(DATA_DIR, usage="train-clean-100", num_samples=3)
|
||||
assert data4.get_dataset_size() == 3
|
||||
assert data4.get_batch_size() == 1
|
||||
data4 = data4.batch(batch_size=2) # drop_remainder is default to be False.
|
||||
assert data4.get_dataset_size() == 2
|
||||
assert data4.get_batch_size() == 2
|
||||
|
||||
# case 5: test batch with drop_remainder=True.
|
||||
data5 = ds.LibriTTSDataset(DATA_DIR, usage="train-clean-100", num_samples=3)
|
||||
assert data5.get_dataset_size() == 3
|
||||
assert data5.get_batch_size() == 1
|
||||
# the rest of incomplete batch will be dropped.
|
||||
data5 = data5.batch(batch_size=2, drop_remainder=True)
|
||||
assert data5.get_dataset_size() == 1
|
||||
assert data5.get_batch_size() == 2
|
||||
|
||||
|
||||
def test_libri_tts_distribute_sampler():
|
||||
"""
|
||||
Feature: LibriTTSDataset
|
||||
Description: test LibriTTS dataset with DisributeSampler
|
||||
Expectation: the results are as expected
|
||||
"""
|
||||
logger.info("Test LibriTTS with sharding")
|
||||
|
||||
list1, list2 = [], []
|
||||
num_shards = 3
|
||||
shard_id = 0
|
||||
|
||||
data1 = ds.LibriTTSDataset(DATA_DIR, usage="all", num_shards=num_shards, shard_id=shard_id)
|
||||
count = 0
|
||||
for item1 in data1.create_dict_iterator(output_numpy=True, num_epochs=1):
|
||||
list1.append(item1["original_text"])
|
||||
count = count + 1
|
||||
assert count == 1
|
||||
|
||||
num_shards = 3
|
||||
shard_id = 0
|
||||
sampler = ds.DistributedSampler(num_shards, shard_id)
|
||||
data2 = ds.LibriTTSDataset(DATA_DIR, usage="train-clean-100", sampler=sampler)
|
||||
count = 0
|
||||
for item2 in data2.create_dict_iterator(output_numpy=True, num_epochs=1):
|
||||
list2.append(item2["original_text"])
|
||||
count = count + 1
|
||||
assert count == 1
|
||||
|
||||
|
||||
def test_libri_tts_exception():
|
||||
"""
|
||||
Feature: LibriTTSDataset
|
||||
Description: test error cases for LibriTTSDataset
|
||||
Expectation: the results are as expected
|
||||
"""
|
||||
logger.info("Test error cases for LibriTTSDataset")
|
||||
|
||||
error_msg_1 = "sampler and shuffle cannot be specified at the same time"
|
||||
with pytest.raises(RuntimeError, match=error_msg_1):
|
||||
ds.LibriTTSDataset(DATA_DIR, shuffle=False, sampler=ds.PKSampler(3))
|
||||
|
||||
error_msg_2 = "sampler and sharding cannot be specified at the same time"
|
||||
with pytest.raises(RuntimeError, match=error_msg_2):
|
||||
ds.LibriTTSDataset(DATA_DIR, sampler=ds.PKSampler(3), num_shards=2, shard_id=0)
|
||||
|
||||
error_msg_3 = "num_shards is specified and currently requires shard_id as well"
|
||||
with pytest.raises(RuntimeError, match=error_msg_3):
|
||||
ds.LibriTTSDataset(DATA_DIR, num_shards=10)
|
||||
|
||||
error_msg_4 = "shard_id is specified but num_shards is not"
|
||||
with pytest.raises(RuntimeError, match=error_msg_4):
|
||||
ds.LibriTTSDataset(DATA_DIR, shard_id=0)
|
||||
|
||||
error_msg_5 = "Input shard_id is not within the required interval"
|
||||
with pytest.raises(ValueError, match=error_msg_5):
|
||||
ds.LibriTTSDataset(DATA_DIR, num_shards=5, shard_id=-1)
|
||||
with pytest.raises(ValueError, match=error_msg_5):
|
||||
ds.LibriTTSDataset(DATA_DIR, num_shards=5, shard_id=5)
|
||||
with pytest.raises(ValueError, match=error_msg_5):
|
||||
ds.LibriTTSDataset(DATA_DIR, num_shards=2, shard_id=5)
|
||||
|
||||
error_msg_6 = "num_parallel_workers exceeds"
|
||||
with pytest.raises(ValueError, match=error_msg_6):
|
||||
ds.LibriTTSDataset(DATA_DIR, shuffle=False, num_parallel_workers=0)
|
||||
with pytest.raises(ValueError, match=error_msg_6):
|
||||
ds.LibriTTSDataset(DATA_DIR, shuffle=False, num_parallel_workers=256)
|
||||
with pytest.raises(ValueError, match=error_msg_6):
|
||||
ds.LibriTTSDataset(DATA_DIR, shuffle=False, num_parallel_workers=-2)
|
||||
|
||||
error_msg_7 = "Argument shard_id"
|
||||
with pytest.raises(TypeError, match=error_msg_7):
|
||||
ds.LibriTTSDataset(DATA_DIR, num_shards=2, shard_id="0")
|
||||
|
||||
def exception_func(item):
|
||||
raise Exception("Error occur!")
|
||||
|
||||
error_msg_8 = "The corresponding data files"
|
||||
with pytest.raises(RuntimeError, match=error_msg_8):
|
||||
data = ds.LibriTTSDataset(DATA_DIR)
|
||||
data = data.map(operations=exception_func, input_columns=["waveform"], num_parallel_workers=1)
|
||||
for _ in data.create_dict_iterator(output_numpy=True, num_epochs=1):
|
||||
pass
|
||||
|
||||
|
||||
def test_libri_tts_sequential_sampler():
|
||||
"""
|
||||
Feature: LibriTTSDataset
|
||||
Description: test LibriTTSDataset with SequentialSampler
|
||||
Expectation: the results are as expected
|
||||
"""
|
||||
logger.info("Test LibriTTSDataset Op with SequentialSampler")
|
||||
|
||||
num_samples = 2
|
||||
sampler = ds.SequentialSampler(num_samples=num_samples)
|
||||
data1 = ds.LibriTTSDataset(DATA_DIR, usage="train-clean-100", sampler=sampler)
|
||||
data2 = ds.LibriTTSDataset(DATA_DIR, usage="train-clean-100", shuffle=False, num_samples=num_samples)
|
||||
list1, list2 = [], []
|
||||
list_expected = [24000, b'good morning', b'Good morning', 2506, 11267, b'2506_11267_000001_000000',
|
||||
24000, b'good afternoon', b'Good afternoon', 2506, 11267, b'2506_11267_000002_000000']
|
||||
|
||||
num_iter = 0
|
||||
for item1, item2 in zip(data1.create_dict_iterator(output_numpy=True, num_epochs=1),
|
||||
data2.create_dict_iterator(output_numpy=True, num_epochs=1)):
|
||||
list1.append(item1["sample_rate"])
|
||||
list2.append(item2["sample_rate"])
|
||||
list1.append(item1["original_text"])
|
||||
list2.append(item2["original_text"])
|
||||
list1.append(item1["normalized_text"])
|
||||
list2.append(item2["normalized_text"])
|
||||
list1.append(item1["speaker_id"])
|
||||
list2.append(item2["speaker_id"])
|
||||
list1.append(item1["chapter_id"])
|
||||
list2.append(item2["chapter_id"])
|
||||
list1.append(item1["utterance_id"])
|
||||
list2.append(item2["utterance_id"])
|
||||
num_iter += 1
|
||||
np.testing.assert_array_equal(list1, list_expected)
|
||||
np.testing.assert_array_equal(list2, list_expected)
|
||||
assert num_iter == num_samples
|
||||
|
||||
|
||||
def test_libri_tts_usage():
|
||||
"""
|
||||
Feature: LibriTTSDataset
|
||||
Description: test LibriTTSDataset usage
|
||||
Expectation: the results are as expected
|
||||
"""
|
||||
logger.info("Test LibriTTSDataset usage")
|
||||
|
||||
def test_config(usage, libri_tts_path=None):
|
||||
libri_tts_path = DATA_DIR if libri_tts_path is None else libri_tts_path
|
||||
try:
|
||||
data = ds.LibriTTSDataset(libri_tts_path, usage=usage, shuffle=False)
|
||||
num_rows = 0
|
||||
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
num_rows += 1
|
||||
except (ValueError, TypeError, RuntimeError) as e:
|
||||
return str(e)
|
||||
return num_rows
|
||||
|
||||
assert test_config("all") == 3
|
||||
assert test_config("train-clean-100") == 3
|
||||
assert "Input usage is not within the valid set of ['dev-clean', 'dev-other', 'test-clean', 'test-other', " \
|
||||
"'train-clean-100', 'train-clean-360', 'train-other-500', 'all']." in test_config("invalid")
|
||||
assert "Argument usage with value ['list'] is not of type [<class 'str'>]" in test_config(["list"])
|
||||
|
||||
all_files_path = None
|
||||
if all_files_path is not None:
|
||||
assert test_config("train-clean-100", all_files_path) == 3
|
||||
assert ds.LibriTTSDataset(all_files_path, usage="train-clean-100").get_dataset_size() == 3
|
||||
assert test_config("all", all_files_path) == 3
|
||||
assert ds.LibriTTSDataset(all_files_path, usage="all").get_dataset_size() == 3
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_libri_tts_basic()
|
||||
test_libri_tts_distribute_sampler()
|
||||
test_libri_tts_exception()
|
||||
test_libri_tts_sequential_sampler()
|
||||
test_libri_tts_usage()
|
Loading…
Reference in New Issue