!20576 [assistant][ops] add new dataset loading operator LJSpeechDataset

Merge pull request !20576 from 杨旭华/LJSpeechDataset
This commit is contained in:
i-robot 2021-11-16 07:00:20 +00:00 committed by Gitee
commit 58b69a05ee
26 changed files with 1159 additions and 1 deletions

View File

@ -103,6 +103,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/fashion_mnist_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/lj_speech_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/photo_tour_node.h"
@ -1209,6 +1210,27 @@ ImageFolderDataset::ImageFolderDataset(const std::vector<char> &dataset_dir, boo
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
// Create the LJSpeech IR node from a shared sampler. A null sampler is forwarded to the node as nullptr.
LJSpeechDataset::LJSpeechDataset(const std::vector<char> &dataset_dir, const std::shared_ptr<Sampler> &sampler,
                                 const std::shared_ptr<DatasetCache> &cache) {
  auto parsed_sampler = (sampler != nullptr) ? sampler->Parse() : nullptr;
  const auto root_dir = CharToString(dataset_dir);
  ir_node_ = std::static_pointer_cast<DatasetNode>(
    std::make_shared<LJSpeechNode>(root_dir, parsed_sampler, cache));
}
// Create the LJSpeech IR node from a raw sampler pointer. A null pointer is forwarded to the node as nullptr.
LJSpeechDataset::LJSpeechDataset(const std::vector<char> &dataset_dir, const Sampler *sampler,
                                 const std::shared_ptr<DatasetCache> &cache) {
  auto parsed_sampler = (sampler != nullptr) ? sampler->Parse() : nullptr;
  const auto root_dir = CharToString(dataset_dir);
  ir_node_ = std::static_pointer_cast<DatasetNode>(
    std::make_shared<LJSpeechNode>(root_dir, parsed_sampler, cache));
}
// Create the LJSpeech IR node from a sampler passed by reference wrapper (always a valid object, so no null check).
LJSpeechDataset::LJSpeechDataset(const std::vector<char> &dataset_dir, const std::reference_wrapper<Sampler> sampler,
                                 const std::shared_ptr<DatasetCache> &cache) {
  Sampler &sampler_ref = sampler.get();
  auto parsed_sampler = sampler_ref.Parse();
  const auto root_dir = CharToString(dataset_dir);
  ir_node_ = std::static_pointer_cast<DatasetNode>(
    std::make_shared<LJSpeechNode>(root_dir, parsed_sampler, cache));
}
ManifestDataset::ManifestDataset(const std::vector<char> &dataset_file, const std::vector<char> &usage,
const std::shared_ptr<Sampler> &sampler,
const std::map<std::vector<char>, int32_t> &class_indexing, bool decode,

View File

@ -47,6 +47,7 @@
// IR leaf nodes disabled for android
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/source/lj_speech_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/photo_tour_node.h"
@ -264,6 +265,16 @@ PYBIND_REGISTER(ImageFolderNode, 2, ([](const py::module *m) {
}));
}));
// Python binding for LJSpeechNode: registers the class on the module and exposes a constructor
// taking (dataset_dir, sampler).
PYBIND_REGISTER(LJSpeechNode, 2, ([](const py::module *m) {
(void)py::class_<LJSpeechNode, DatasetNode, std::shared_ptr<LJSpeechNode>>(*m, "LJSpeechNode",
"to create a LJSpeechNode")
.def(py::init([](std::string dataset_dir, py::handle sampler) {
// The Python sampler handle is converted with toSamplerObj; no cache is passed (third argument is nullptr).
auto lj_speech = std::make_shared<LJSpeechNode>(dataset_dir, toSamplerObj(sampler), nullptr);
// Validate the parameters at construction time.
// NOTE(review): THROW_IF_ERROR presumably turns a failed Status into an exception - confirm.
THROW_IF_ERROR(lj_speech->ValidateParams());
return lj_speech;
}));
}));
PYBIND_REGISTER(ManifestNode, 2, ([](const py::module *m) {
(void)py::class_<ManifestNode, DatasetNode, std::shared_ptr<ManifestNode>>(*m, "ManifestNode",
"to create a ManifestNode")

View File

@ -16,10 +16,13 @@
#include "minddata/dataset/audio/kernels/audio_utils.h"
#include <fstream>
#include "mindspore/core/base/float16.h"
#include "minddata/dataset/core/type_id.h"
#include "minddata/dataset/kernels/data/data_utils.h"
#include "minddata/dataset/util/random.h"
#include "utils/file_utils.h"
namespace mindspore {
namespace dataset {
@ -850,5 +853,43 @@ Status GenerateWaveTable(std::shared_ptr<Tensor> *output, const DataType &type,
return Status::OK();
}
// Read a 16-bit PCM wave file and convert its samples to floats in roughly [-1, 1].
//   wav_file_dir: path of the wave file (a file, not a directory).
//   waveform_vec: output, resized to the number of samples and filled with sample / 32767.
//   sample_rate:  output, the sample rate stored in the header.
// Returns an error Status when the path cannot be resolved or opened, when the header is truncated,
// or when the sample width is not 16 bits. Outputs are only written on success.
// NOTE(review): the header is read byte-for-byte into WavHeader, so this assumes a canonical header layout with the
// data chunk directly after it and a little-endian host - confirm for other inputs.
Status ReadWaveFile(const std::string &wav_file_dir, std::vector<float> *waveform_vec, int32_t *sample_rate) {
  RETURN_UNEXPECTED_IF_NULL(waveform_vec);
  RETURN_UNEXPECTED_IF_NULL(sample_rate);
  auto wav_realpath = FileUtils::GetRealPath(wav_file_dir.data());
  if (!wav_realpath.has_value()) {
    MS_LOG(ERROR) << "Invalid file, get real path failed, path=" << wav_file_dir;
    RETURN_STATUS_UNEXPECTED("Invalid file, get real path failed, path=" + wav_file_dir);
  }

  const float kMaxVal = 32767.0;
  const int kDataMove = 2;  // bytes per 16-bit sample
  Path file_path(wav_realpath.value());
  CHECK_FAIL_RETURN_UNEXPECTED(file_path.Exists() && !file_path.IsDirectory(),
                               "Invalid file, failed to find wave file:" + file_path.ToString());
  // The stream closes itself on every return path, so no explicit close() is needed.
  std::ifstream in(file_path.ToString(), std::ios::in | std::ios::binary);
  CHECK_FAIL_RETURN_UNEXPECTED(in.is_open(), "Invalid file, failed to open wave file:" + file_path.ToString() +
                                               ", make sure the file not damaged or permission denied.");

  // Keep the header on the stack: nothing to delete, nothing to leak.
  WavHeader header;
  in.read(reinterpret_cast<char *>(&header), sizeof(WavHeader));
  CHECK_FAIL_RETURN_UNEXPECTED(in.gcount() == static_cast<std::streamsize>(sizeof(WavHeader)),
                               "Invalid file, wave file header is incomplete:" + file_path.ToString());

  const int bytes_per_sample = header.bitsPerSample / 8;
  if (bytes_per_sample == 0) {
    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "ReadWaveFile: divide zero error.");
  }
  // The decode loop below combines two bytes per sample; any other width would be decoded incorrectly
  // (and 8-bit data would be indexed past the end of the buffer).
  CHECK_FAIL_RETURN_UNEXPECTED(bytes_per_sample == kDataMove,
                               "ReadWaveFile: only 16-bit samples are supported, but got bits per sample: " +
                                 std::to_string(header.bitsPerSample) + ", file:" + file_path.ToString());
  CHECK_FAIL_RETURN_UNEXPECTED(header.subChunk2Size >= 0,
                               "Invalid file, wave file data size is negative:" + file_path.ToString());

  const size_t data_size = static_cast<size_t>(header.subChunk2Size);
  // Zero-initialised, so bytes missing from a short file decode as silence.
  std::vector<char> data(data_size);
  in.read(data.data(), static_cast<std::streamsize>(data_size));

  const size_t num_samples = data_size / static_cast<size_t>(kDataMove);
  waveform_vec->resize(num_samples);
  for (size_t i = 0; i < num_samples; ++i) {
    // Assemble the little-endian signed 16-bit sample from its two bytes, then normalise.
    const auto low = static_cast<uint8_t>(data[kDataMove * i]);
    const auto high = static_cast<uint8_t>(data[kDataMove * i + 1]);
    const auto sample = static_cast<int16_t>(static_cast<uint16_t>(low | (high << 8)));
    (*waveform_vec)[i] = static_cast<float>(sample) / kMaxVal;
  }
  *sample_rate = header.sampleRate;
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -989,6 +989,31 @@ Status Flanger(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *out
RETURN_IF_NOT_OK(TypeCast(output_waveform, output, input->type()));
return Status::OK();
}
// A brief structure of wave file header.
// ReadWaveFile reads sizeof(WavHeader) bytes from the start of the file straight into this struct, so the
// member order and sizes have to mirror the on-disk header.
// NOTE(review): the field names follow the canonical RIFF/WAVE header; descriptions marked "presumably" are
// inferred from the names and should be confirmed against the format specification.
struct WavHeader {
int8_t chunkID[4] = {0};  // presumably the "RIFF" tag
int32_t chunkSize = 0;  // presumably the size of the rest of the file
int8_t format[4] = {0};  // presumably the "WAVE" tag
int8_t subChunk1ID[4] = {0};  // presumably the "fmt " tag
int32_t subChunk1Size = 0;  // presumably the size of the format chunk
int16_t audioFormat = 0;  // presumably the encoding id (PCM or other)
int16_t numChannels = 0;  // presumably the channel count
int32_t sampleRate = 0;  // returned by ReadWaveFile as the sample rate
int32_t byteRate = 0;  // presumably bytes per second
int16_t byteAlign = 0;  // presumably bytes per sample frame
int16_t bitsPerSample = 0;  // ReadWaveFile divides this by 8 to get bytes per sample
int8_t subChunk2ID[4] = {0};  // presumably the "data" tag
int32_t subChunk2Size = 0;  // number of audio data bytes ReadWaveFile reads after the header
WavHeader() {}
};
/// \brief Read the audio data of a wav file and store it into a vector.
/// \param[in] wav_file_dir Path of the wave file to read (a file path, despite the `_dir` suffix; a directory is
///     rejected).
/// \param[out] waveform_vec Vector that receives the waveform samples; must not be null.
/// \param[out] sample_rate Receives the sample rate stored in the file header; must not be null.
/// \return Status code.
Status ReadWaveFile(const std::string &wav_file_dir, std::vector<float> *waveform_vec, int32_t *sample_rate);
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_AUDIO_UTILS_H_

View File

@ -24,6 +24,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
qmnist_op.cc
emnist_op.cc
fake_image_op.cc
lj_speech_op.cc
places365_op.cc
photo_tour_op.cc
fashion_mnist_op.cc

View File

@ -0,0 +1,153 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/datasetops/source/lj_speech_op.h"
#include <fstream>
#include <iomanip>
#include <utility>
#include "minddata/dataset/audio/kernels/audio_utils.h"
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/util/path.h"
#include "utils/file_utils.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace dataset {
// Construct the op: num_workers, queue_size and the sampler are forwarded to MappableLeafOp, while the dataset
// directory and the schema are stored on this object. The constructor body is empty; metadata.csv is read later
// by PrepareData().
LJSpeechOp::LJSpeechOp(const std::string &file_dir, int32_t num_workers, int32_t queue_size,
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler)
: MappableLeafOp(num_workers, queue_size, std::move(sampler)),
folder_path_(file_dir),
data_schema_(std::move(data_schema)) {}
// Read <dataset_dir>/metadata.csv and cache one row of '|'-separated fields per line in meta_info_list_.
// Fails when the directory cannot be resolved, the metadata file is missing or unreadable, a line has fewer
// than three fields, or the file yields no rows. On success num_rows_ is set to the number of cached rows.
Status LJSpeechOp::PrepareData() {
  auto real_path = FileUtils::GetRealPath(folder_path_.data());
  if (!real_path.has_value()) {
    RETURN_STATUS_UNEXPECTED("Invalid file, get real path failed, path=" + folder_path_);
  }
  Path root_folder(real_path.value());
  Path metadata_file_path = root_folder / "metadata.csv";
  CHECK_FAIL_RETURN_UNEXPECTED(metadata_file_path.Exists() && !metadata_file_path.IsDirectory(),
                               "Invalid file, failed to find metadata file: " + metadata_file_path.ToString());
  // The stream closes itself on every return path.
  std::ifstream csv_reader(metadata_file_path.ToString());
  CHECK_FAIL_RETURN_UNEXPECTED(csv_reader.is_open(),
                               "Invalid file, failed to open metadata file: " + metadata_file_path.ToString() +
                                 ", make sure file not damaged or permission denied.");

  // LoadTensorRow reads fields 0, 1 and 2 of every row, so reject shorter rows here instead of
  // indexing out of bounds later.
  constexpr size_t kMinFields = 3;
  std::string line;
  size_t line_no = 0;
  while (std::getline(csv_reader, line)) {
    ++line_no;
    std::vector<std::string> row;
    size_t field_begin = 0;
    size_t sep = line.find('|', field_begin);
    while (sep != std::string::npos) {
      row.emplace_back(line.substr(field_begin, sep - field_begin));
      field_begin = sep + 1;
      sep = line.find('|', field_begin);
    }
    // The remainder after the last separator is always a field (possibly empty).
    row.emplace_back(line.substr(field_begin));
    CHECK_FAIL_RETURN_UNEXPECTED(
      row.size() >= kMinFields,
      "Invalid data, line " + std::to_string(line_no) + " of metadata file " + metadata_file_path.ToString() +
        " should contain at least 3 fields separated by '|', but got " + std::to_string(row.size()) + ".");
    meta_info_list_.emplace_back(std::move(row));
  }
  if (meta_info_list_.empty()) {
    RETURN_STATUS_UNEXPECTED(
      "Reading failed, unable to read valid data from the metadata file: " + metadata_file_path.ToString() + ".");
  }
  num_rows_ = meta_info_list_.size();
  return Status::OK();
}
// Load 1 TensorRow (waveform, sample_rate, transcription, normalized_transcription).
// 1 function call produces 1 TensorRow.
//   index: position of the metadata row to load; must be inside meta_info_list_.
//   trow:  output row; also receives the source file path of each column.
// Returns an error Status when the index is out of range, the metadata row has fewer than three fields,
// or the wave file cannot be read.
Status LJSpeechOp::LoadTensorRow(row_id_type index, TensorRow *trow) {
  CHECK_FAIL_RETURN_UNEXPECTED(index >= 0 && static_cast<size_t>(index) < meta_info_list_.size(),
                               "The input index is out of range.");
  // Guard the [0..2] accesses below: a malformed metadata line yields a shorter row.
  constexpr size_t kNumMetaFields = 3;
  const std::vector<std::string> &meta = meta_info_list_[index];
  CHECK_FAIL_RETURN_UNEXPECTED(meta.size() >= kNumMetaFields,
                               "Invalid data, metadata row " + std::to_string(index) +
                                 " should contain at least 3 fields separated by '|', but got " +
                                 std::to_string(meta.size()) + ".");
  const std::string &file_name_pref = meta[0];
  const std::string &transcription_str = meta[1];
  const std::string &normalized_transcription_str = meta[2];

  // The audio for a row lives at <dataset_dir>/wavs/<first field>.wav.
  const std::string file_name = file_name_pref + ".wav";
  Path root_folder(folder_path_);
  Path wav_file_path = root_folder / "wavs" / file_name;
  Path metadata_file_path = root_folder / "metadata.csv";

  int32_t sample_rate = 0;
  std::vector<float> waveform_vec;
  RETURN_IF_NOT_OK(ReadWaveFile(wav_file_path.ToString(), &waveform_vec, &sample_rate));

  std::shared_ptr<Tensor> waveform;
  std::shared_ptr<Tensor> sample_rate_scalar;
  std::shared_ptr<Tensor> transcription;
  std::shared_ptr<Tensor> normalized_transcription;
  RETURN_IF_NOT_OK(Tensor::CreateFromVector(waveform_vec, &waveform));
  RETURN_IF_NOT_OK(waveform->ExpandDim(0));
  RETURN_IF_NOT_OK(Tensor::CreateScalar(sample_rate, &sample_rate_scalar));
  RETURN_IF_NOT_OK(Tensor::CreateScalar(transcription_str, &transcription));
  RETURN_IF_NOT_OK(Tensor::CreateScalar(normalized_transcription_str, &normalized_transcription));
  (*trow) = TensorRow(index, {waveform, sample_rate_scalar, transcription, normalized_transcription});
  // Add file path info: the waveform comes from the wav file, the other three columns from metadata.csv.
  trow->setPath({wav_file_path.ToString(), metadata_file_path.ToString(), metadata_file_path.ToString(),
                 metadata_file_path.ToString()});
  return Status::OK();
}
// Debug printer: the base class output always comes first, followed by either a bare newline (summary mode)
// or the row count and dataset directory (detailed mode).
void LJSpeechOp::Print(std::ostream &out, bool show_all) const {
  // Common info, one-liner or detailed depending on show_all.
  ParallelOp::Print(out, show_all);
  if (show_all) {
    // Op-specific details.
    out << "\nNumber of rows: " << num_rows_ << "\nLJSpeech directory: " << folder_path_ << "\n\n";
    return;
  }
  // Summary mode has no op-specific one-liner info; just terminate the line.
  out << "\n";
}
// Count the lines of <dir>/metadata.csv; every line is counted as one dataset row.
//   dir:   dataset root directory.
//   count: output, number of lines in the metadata file; must not be null.
// Returns an error Status when count is null, the directory cannot be resolved, or the metadata file is
// missing or cannot be opened.
Status LJSpeechOp::CountTotalRows(const std::string &dir, int64_t *count) {
  if (count == nullptr) {
    RETURN_STATUS_UNEXPECTED("The pointer[count] is null.");
  }
  auto real_path = FileUtils::GetRealPath(dir.data());
  if (!real_path.has_value()) {
    RETURN_STATUS_UNEXPECTED("Invalid file, get real path failed, path=" + dir);
  }
  Path root_folder(real_path.value());
  Path metadata_file_path = root_folder / "metadata.csv";
  CHECK_FAIL_RETURN_UNEXPECTED(metadata_file_path.Exists() && !metadata_file_path.IsDirectory(),
                               "Invalid file, failed to find metadata file: " + metadata_file_path.ToString());
  // The stream closes itself on every return path.
  std::ifstream csv_reader(metadata_file_path.ToString());
  CHECK_FAIL_RETURN_UNEXPECTED(csv_reader.is_open(),
                               "Invalid file, failed to open metadata file: " + metadata_file_path.ToString() +
                                 ", make sure file not damaged or permission denied.");
  std::string line;
  int64_t num_lines = 0;
  while (std::getline(csv_reader, line)) {
    ++num_lines;
  }
  *count = num_lines;
  return Status::OK();
}
// Fill the base-class column name map from the schema (column name -> column position).
// When the map already has entries it is left untouched and a warning is logged.
Status LJSpeechOp::ComputeColMap() {
  if (!column_name_id_map_.empty()) {
    MS_LOG(WARNING) << "Column name map is already set!";
    return Status::OK();
  }
  for (int32_t col = 0; col < data_schema_->NumColumns(); ++col) {
    column_name_id_map_[data_schema_->Column(col).Name()] = col;
  }
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,86 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_LJ_SPEECH_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_LJ_SPEECH_OP_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/util/services.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/wait_post.h"
namespace mindspore {
namespace dataset {
/// \brief Leaf op that reads the LJSpeech dataset: one row per line of metadata.csv, each paired with a wav file.
class LJSpeechOp : public MappableLeafOp {
public:
/// \brief Constructor.
/// \param[in] file_dir Root directory of the lj_speech dataset (the directory that holds metadata.csv).
/// \param[in] num_workers Number of workers reading audios in parallel.
/// \param[in] queue_size Connector queue size.
/// \param[in] data_schema Data schema of lj_speech dataset; ownership is transferred to the op.
/// \param[in] sampler Sampler tells LJSpeechOp what to read.
LJSpeechOp(const std::string &file_dir, int32_t num_workers, int32_t queue_size,
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler);
/// \brief Destructor.
~LJSpeechOp() = default;
/// \brief A print method typically used for debugging.
/// \param[out] out The output stream to write output to.
/// \param[in] show_all A bool to control if you want to show all info or just a summary.
void Print(std::ostream &out, bool show_all) const override;
/// \brief Function to count the number of samples in the LJSpeech dataset.
/// \param[in] dir Path to the directory of LJSpeech dataset.
/// \param[out] count Output arg that will hold the actual dataset size.
/// \return Status
static Status CountTotalRows(const std::string &dir, int64_t *count);
/// \brief Op name getter.
/// \return Name of the current Op.
std::string Name() const override { return "LJSpeechOp"; }
protected:
/// \brief Read metadata.csv from the dataset directory and cache its rows for LoadTensorRow.
/// \note NOTE(review): when this is invoked is decided by the base class - confirm against MappableLeafOp.
/// \return Status
Status PrepareData() override;
private:
/// \brief Load a tensor row.
/// \param[in] index Index of the metadata row to load.
/// \param[out] trow Waveform & sample_rate & transcription & normalized_transcription read into this tensor row.
/// \return Status the status code returned.
Status LoadTensorRow(row_id_type index, TensorRow *trow) override;
/// \brief Private function for computing the assignment of the column name map.
/// \return Status
Status ComputeColMap() override;
// Dataset root directory as given to the constructor.
std::string folder_path_;
// Schema describing the output columns; its column names feed ComputeColMap.
std::unique_ptr<DataSchema> data_schema_;
// One entry per metadata.csv line, holding that line's '|'-separated fields.
std::vector<std::vector<std::string>> meta_info_list_; // the shape is (N, 3)
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_LJ_SPEECH_OP_H_

View File

@ -91,6 +91,7 @@ constexpr char kFashionMnistNode[] = "FashionMnistDataset";
constexpr char kFlickrNode[] = "FlickrDataset";
constexpr char kGeneratorNode[] = "GeneratorDataset";
constexpr char kImageFolderNode[] = "ImageFolderDataset";
constexpr char kLJSpeechNode[] = "LJSpeechDataset";
constexpr char kManifestNode[] = "ManifestDataset";
constexpr char kMindDataNode[] = "MindDataDataset";
constexpr char kMnistNode[] = "MnistDataset";

View File

@ -19,6 +19,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
fashion_mnist_node.cc
flickr_node.cc
image_folder_node.cc
lj_speech_node.cc
manifest_node.cc
minddata_node.cc
mnist_node.cc

View File

@ -0,0 +1,117 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/source/lj_speech_node.h"
#include <utility>
#include "minddata/dataset/engine/datasetops/source/lj_speech_op.h"
namespace mindspore {
namespace dataset {
// Constructor for LJSpeechNode.
// Stores the dataset directory and the sampler on the node and hands the cache to MappableSourceNode.
// Nothing is validated here; see ValidateParams().
LJSpeechNode::LJSpeechNode(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler,
std::shared_ptr<DatasetCache> cache)
: MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), sampler_(sampler) {}
std::shared_ptr<DatasetNode> LJSpeechNode::Copy() {
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
auto node = std::make_shared<LJSpeechNode>(dataset_dir_, sampler, cache_);
return node;
}
// Write "<node name>(cache: true|false)" to the stream as a single insertion.
void LJSpeechNode::Print(std::ostream &out) const {
  const char *cache_state = (cache_ == nullptr) ? "false" : "true";
  std::string description = Name();
  description += "(cache: ";
  description += cache_state;
  description += ")";
  out << description;
}
// Run the base-class validation, then check the dataset directory and the sampler.
// The first failing check determines the returned Status.
Status LJSpeechNode::ValidateParams() {
  Status rc = DatasetNode::ValidateParams();
  if (rc.IsError()) {
    return rc;
  }
  rc = ValidateDatasetDirParam("LJSpeechNode", dataset_dir_);
  if (rc.IsError()) {
    return rc;
  }
  rc = ValidateDatasetSampler("LJSpeechNode", sampler_);
  if (rc.IsError()) {
    return rc;
  }
  return Status::OK();
}
// Build the runtime LJSpeechOp for this node and append it to node_ops.
// The schema is generated here: "waveform" (float32), then the scalar columns "sample_rate" (int32),
// "transcription" (string) and "normalized_transcription" (string), in that order.
Status LJSpeechNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
  auto data_schema = std::make_unique<DataSchema>();
  RETURN_IF_NOT_OK(
    data_schema->AddColumn(ColDescriptor("waveform", DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1)));

  // Each scalar column gets its own scalar shape object.
  TensorShape rate_shape = TensorShape::CreateScalar();
  TensorShape transcription_shape = TensorShape::CreateScalar();
  TensorShape normalized_shape = TensorShape::CreateScalar();
  RETURN_IF_NOT_OK(data_schema->AddColumn(
    ColDescriptor("sample_rate", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &rate_shape)));
  RETURN_IF_NOT_OK(data_schema->AddColumn(
    ColDescriptor("transcription", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &transcription_shape)));
  RETURN_IF_NOT_OK(data_schema->AddColumn(ColDescriptor("normalized_transcription", DataType(DataType::DE_STRING),
                                                        TensorImpl::kFlexible, 0, &normalized_shape)));

  // Turn the IR sampler into its runtime counterpart.
  std::shared_ptr<SamplerRT> runtime_sampler = nullptr;
  RETURN_IF_NOT_OK(sampler_->SamplerBuild(&runtime_sampler));

  auto op = std::make_shared<LJSpeechOp>(dataset_dir_, num_workers_, connector_que_size_, std::move(data_schema),
                                         std::move(runtime_sampler));
  op->SetTotalRepeats(GetTotalRepeats());
  op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
  node_ops->push_back(op);
  return Status::OK();
}
// Get the shard id of node.
// Writes the shard id reported by the node's sampler into *shard_id and returns OK.
// NOTE(review): neither shard_id nor sampler_ is null-checked here - presumably callers guarantee both; confirm.
Status LJSpeechNode::GetShardId(int32_t *shard_id) {
*shard_id = sampler_->ShardId();
return Status::OK();
}
// Get Dataset size.
// A previously computed positive size is returned from the cache. Otherwise the metadata rows are counted and
// the sampler decides how many of them are produced; when the sampler reports -1 the size getter performs a
// dry run instead. The result is cached in dataset_size_. The estimate flag is not used by this node.
Status LJSpeechNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                                    int64_t *dataset_size) {
  if (dataset_size_ > 0) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }

  int64_t total_rows = 0;
  RETURN_IF_NOT_OK(LJSpeechOp::CountTotalRows(dataset_dir_, &total_rows));

  std::shared_ptr<SamplerRT> runtime_sampler = nullptr;
  RETURN_IF_NOT_OK(sampler_->SamplerBuild(&runtime_sampler));

  int64_t num_samples = runtime_sampler->CalculateNumSamples(total_rows);
  if (num_samples == -1) {
    RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &num_samples));
  }

  *dataset_size = num_samples;
  dataset_size_ = num_samples;
  return Status::OK();
}
// Serialize the node arguments: "sampler", "num_parallel_workers", "dataset_dir" and, when a cache is set,
// "cache". The result replaces *out_json.
Status LJSpeechNode::to_json(nlohmann::json *out_json) {
  nlohmann::json sampler_json;
  RETURN_IF_NOT_OK(sampler_->to_json(&sampler_json));

  nlohmann::json node_json;
  node_json["sampler"] = sampler_json;
  node_json["num_parallel_workers"] = num_workers_;
  node_json["dataset_dir"] = dataset_dir_;
  if (cache_ != nullptr) {
    nlohmann::json cache_json;
    RETURN_IF_NOT_OK(cache_->to_json(&cache_json));
    node_json["cache"] = cache_json;
  }

  *out_json = node_json;
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,95 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_LJ_SPEECH_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_LJ_SPEECH_NODE_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
/// \brief IR node describing an LJSpeech source dataset (dataset directory plus sampler and optional cache).
class LJSpeechNode : public MappableSourceNode {
public:
/// \brief Constructor.
/// \param[in] dataset_dir Path to the root directory of the dataset.
/// \param[in] sampler Sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use; may be nullptr.
LJSpeechNode(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler,
std::shared_ptr<DatasetCache> cache);
/// \brief Destructor.
~LJSpeechNode() = default;
/// \brief Node name getter.
/// \return Name of the current node.
std::string Name() const override { return kLJSpeechNode; }
/// \brief Print the description.
/// \param[out] out The output stream to write output to.
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object.
/// \return A shared pointer to the new copy.
std::shared_ptr<DatasetNode> Copy() override;
/// \brief A base class override function to create the required runtime dataset op objects for this class.
/// \param[out] node_ops A vector that receives the shared pointer to the Dataset Op this object creates.
/// \return Status Status::OK() if build successfully.
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
/// \brief Parameters validation.
/// \return Status Status::OK() if all the parameters are valid.
Status ValidateParams() override;
/// \brief Get the shard id of node.
/// \param[out] shard_id Receives the shard id reported by the sampler.
/// \return Status Status::OK() if get shard id successfully.
Status GetShardId(int32_t *shard_id) override;
/// \brief Base-class override for GetDatasetSize.
/// \param[in] size_getter Shared pointer to DatasetSizeGetter.
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size The size of the dataset.
/// \return Status of the function.
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
/// \brief Getter functions.
/// \return Path string of the dataset.
const std::string &DatasetDir() const { return dataset_dir_; }
/// \brief Get the arguments of node.
/// \param[out] out_json JSON string of all attributes.
/// \return Status of the function.
Status to_json(nlohmann::json *out_json) override;
/// \brief Sampler getter.
/// \return SamplerObj of the current node.
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
/// \brief Sampler setter.
/// \param[in] sampler SamplerObj that replaces the current one.
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
private:
// Root directory of the dataset, as passed to the constructor.
std::string dataset_dir_;
// Sampler used to build the runtime sampler and to compute the dataset size.
std::shared_ptr<SamplerObj> sampler_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_LJ_SPEECH_NODE_H_

View File

@ -2506,6 +2506,75 @@ inline std::shared_ptr<ImageFolderDataset> ImageFolder(const std::string &datase
MapStringToChar(class_indexing), cache);
}
/// \class LJSpeechDataset
/// \brief A source dataset for reading and parsing LJSpeech dataset.
class LJSpeechDataset : public Dataset {
public:
/// \brief Constructor of LJSpeechDataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use.
LJSpeechDataset(const std::vector<char> &dataset_dir, const std::shared_ptr<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache);
/// \brief Constructor of LJSpeechDataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] sampler Sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use.
LJSpeechDataset(const std::vector<char> &dataset_dir, const std::reference_wrapper<Sampler> sampler,
const std::shared_ptr<DatasetCache> &cache);
/// \brief Constructor of LJSpeechDataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use.
LJSpeechDataset(const std::vector<char> &dataset_dir, const Sampler *sampler,
const std::shared_ptr<DatasetCache> &cache);
/// \brief Destructor of LJSpeechDataset.
~LJSpeechDataset() = default;
};
/// \brief Create an LJSpeechDataset from a dataset directory and a shared sampler.
/// \note The generated dataset has four columns ["waveform", "sample_rate", "transcription",
///     "normalized_transcription"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
///     given, a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()).
/// \param[in] cache Tensor cache to use. (default=nullptr, which means no cache is used).
/// \return Shared pointer to the newly created LJSpeechDataset.
inline std::shared_ptr<LJSpeechDataset> LJSpeech(
  const std::string &dataset_dir, const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
  const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  return std::make_shared<LJSpeechDataset>(dir_chars, sampler, cache);
}
/// \brief Create an LJSpeechDataset from a dataset directory and a raw sampler pointer.
/// \note The generated dataset has four columns ["waveform", "sample_rate", "transcription",
///     "normalized_transcription"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use. (default=nullptr, which means no cache is used).
/// \return Shared pointer to the newly created LJSpeechDataset.
inline std::shared_ptr<LJSpeechDataset> LJSpeech(const std::string &dataset_dir, Sampler *sampler,
                                                 const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  return std::make_shared<LJSpeechDataset>(dir_chars, sampler, cache);
}
/// \brief Create an LJSpeechDataset from a dataset directory and a sampler passed by reference wrapper.
/// \note The generated dataset has four columns ["waveform", "sample_rate", "transcription",
///     "normalized_transcription"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] sampler Sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use. (default=nullptr, which means no cache is used).
/// \return Shared pointer to the newly created LJSpeechDataset.
inline std::shared_ptr<LJSpeechDataset> LJSpeech(const std::string &dataset_dir,
                                                 const std::reference_wrapper<Sampler> sampler,
                                                 const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  return std::make_shared<LJSpeechDataset>(dir_chars, sampler, cache);
}
/// \class ManifestDataset
/// \brief A source dataset for reading and parsing Manifest dataset.
class ManifestDataset : public Dataset {

View File

@ -44,6 +44,7 @@ class Sampler : std::enable_shared_from_this<Sampler> {
friend class FashionMnistDataset;
friend class FlickrDataset;
friend class ImageFolderDataset;
friend class LJSpeechDataset;
friend class ManifestDataset;
friend class MindDataDataset;
friend class MnistDataset;

View File

@ -68,7 +68,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, check_flickr_dataset, \
check_sb_dataset, check_flowers102dataset, check_cityscapes_dataset, check_usps_dataset, check_div2k_dataset, \
check_sbu_dataset, check_qmnist_dataset, check_emnist_dataset, check_fake_image_dataset, check_places365_dataset, \
check_photo_tour_dataset, check_ag_news_dataset, check_dbpedia_dataset
check_photo_tour_dataset, check_ag_news_dataset, check_dbpedia_dataset, check_lj_speech_dataset
from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
get_prefetch_size
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
@ -6843,6 +6843,142 @@ class Flowers102Dataset(GeneratorDataset):
return class_dict
class LJSpeechDataset(MappableDataset):
    """
    A source dataset for reading and parsing LJSpeech dataset.

    The generated dataset has four columns
    :py:obj:`[waveform, sample_rate, transcription, normalized_transcription]`.
    The tensor of column :py:obj:`waveform` is a tensor of the float32 type.
    The tensor of column :py:obj:`sample_rate` is a scalar of the int32 type.
    The tensor of column :py:obj:`transcription` is a scalar of the string type.
    The tensor of column :py:obj:`normalized_transcription` is a scalar of the string type.

    Args:
        dataset_dir (str): Path to the root directory that contains the dataset.
        num_samples (int, optional): The number of audios to be included in the dataset
            (default=None, all audios).
        num_parallel_workers (int, optional): Number of workers to read the data
            (default=None, number set in the config).
        shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected
            order behavior shown in the table).
        sampler (Sampler, optional): Object used to choose samples from the
            dataset (default=None, expected order behavior shown in the table).
        num_shards (int, optional): Number of shards that the dataset will be divided
            into (default=None). When this argument is specified, `num_samples` reflects
            the maximum sample number of per shard.
        shard_id (int, optional): The shard ID within num_shards (default=None). This
            argument can only be specified when num_shards is also specified.
        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
            (default=None, which means no cache is used).

    Raises:
        RuntimeError: If dataset_dir does not contain data files.
        ValueError: If num_parallel_workers exceeds the max thread numbers.
        RuntimeError: If sampler and shuffle are specified at the same time.
        RuntimeError: If sampler and sharding are specified at the same time.
        RuntimeError: If num_shards is specified but shard_id is None.
        RuntimeError: If shard_id is specified but num_shards is None.
        ValueError: If shard_id is invalid (< 0 or >= num_shards).

    Note:
        - This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive.
          The table below shows what input arguments are allowed and their expected behavior.

    .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
       :widths: 25 25 50
       :header-rows: 1

       * - Parameter `sampler`
         - Parameter `shuffle`
         - Expected Order Behavior
       * - None
         - None
         - random order
       * - None
         - True
         - random order
       * - None
         - False
         - sequential order
       * - Sampler object
         - None
         - order defined by sampler
       * - Sampler object
         - True
         - not allowed
       * - Sampler object
         - False
         - not allowed

    Examples:
        >>> lj_speech_dataset_dir = "/path/to/lj_speech_dataset_directory"
        >>>
        >>> # 1) Get all samples from LJSPEECH dataset in sequence
        >>> dataset = ds.LJSpeechDataset(dataset_dir=lj_speech_dataset_dir, shuffle=False)
        >>>
        >>> # 2) Randomly select 350 samples from LJSPEECH dataset
        >>> dataset = ds.LJSpeechDataset(dataset_dir=lj_speech_dataset_dir, num_samples=350, shuffle=True)
        >>>
        >>> # 3) Get samples from LJSPEECH dataset for shard 0 in a 2-way distributed training
        >>> dataset = ds.LJSpeechDataset(dataset_dir=lj_speech_dataset_dir, num_shards=2, shard_id=0)
        >>>
        >>> # In LJSPEECH dataset, each dictionary has keys "waveform", "sample_rate", "transcription"
        >>> # and "normalized_transcription"

    About LJSPEECH dataset:

    This is a public domain speech dataset consisting of 13,100 short audio clips of a single speaker
    reading passages from 7 non-fiction books. A transcription is provided for each clip.
    Clips vary in length from 1 to 10 seconds and have a total length of approximately 24 hours.

    The texts were published between 1884 and 1964, and are in the public domain.
    The audio was recorded in 2016-17 by the LibriVox project and is also in the public domain.

    Here is the original LJSPEECH dataset structure.
    You can unzip the dataset files into the following directory structure and read by MindSpore's API.

    .. code-block::

        .
        LJSpeech-1.1
            README
            metadata.csv
            wavs
                LJ001-0001.wav
                LJ001-0002.wav
                LJ001-0003.wav
                LJ001-0004.wav
                LJ001-0005.wav
                LJ001-0006.wav
                LJ001-0007.wav
                LJ001-0008.wav
                ...
                LJ050-0277.wav
                LJ050-0278.wav

    Citation:

    .. code-block::

        @misc{lj_speech17,
        author       = {Keith Ito and Linda Johnson},
        title        = {The LJ Speech Dataset},
        howpublished = {url{https://keithito.com/LJ-Speech-Dataset}},
        year         = 2017
        }
    """

    @check_lj_speech_dataset
    def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None, shuffle=None,
                 sampler=None, num_shards=None, shard_id=None, cache=None):
        # Sampling, sharding, worker and cache options are all handed to the base class;
        # only the dataset location is stored on this class.
        super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
                         shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)

        self.dataset_dir = dataset_dir

    def parse(self, children=None):
        """Build the cde.LJSpeechNode for this dataset from its directory and sampler."""
        return cde.LJSpeechNode(self.dataset_dir, self.sampler)
class TextFileDataset(SourceDataset):
"""
A source dataset that reads and parses datasets stored on disk in text format.

View File

@ -421,6 +421,32 @@ def check_celebadataset(method):
return new_method
def check_lj_speech_dataset(method):
    """Wrap `method` (the LJSpeechDataset constructor) with checks on its arguments."""

    @wraps(method)
    def checked_method(self, *args, **kwargs):
        # Resolve positional and keyword arguments into a name -> value mapping.
        _, params = parse_user_args(method, *args, **kwargs)

        # The dataset location is checked first, before any optional parameter.
        check_dir(params.get('dataset_dir'))

        # Optional parameters are handed to the shared validator together with their expected type.
        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        validate_dataset_param_value(['shuffle'], params, bool)
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        # All checks ran without raising: forward the call with the caller's original arguments.
        return method(self, *args, **kwargs)

    return checked_method
def check_save(method):
"""A wrapper that wraps a parameter checker around the saved operator."""

View File

@ -29,6 +29,7 @@ SET(DE_UT_SRCS
c_api_dataset_fashion_mnist_test.cc
c_api_dataset_flickr_test.cc
c_api_dataset_iterator_test.cc
c_api_dataset_lj_speech_test.cc
c_api_dataset_manifest_test.cc
c_api_dataset_minddata_test.cc
c_api_dataset_ops_test.cc

View File

@ -0,0 +1,205 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "minddata/dataset/include/dataset/datasets.h"
using namespace mindspore::dataset;
using mindspore::dataset::DataType;
using mindspore::dataset::Tensor;
using mindspore::dataset::TensorShape;
/// Test fixture used by the TEST_F cases in this file; it adds no members of its own
/// on top of UT::DatasetOpTesting.
class MindDataTestPipeline : public UT::DatasetOpTesting {
protected:
};
/// Feature: LJSpeechDataset
/// Description: basic test of LJSpeechDataset
/// Expectation: the data is processed successfully
TEST_F(MindDataTestPipeline, TestLJSpeechDataset) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLJSpeechDataset.";
  std::string folder_path = datasets_root_path_ + "/testLJSpeechData/";
  std::shared_ptr<Dataset> ds = LJSpeech(folder_path, std::make_shared<RandomSampler>(false, 3));
  // ds is dereferenced right below, so a null pointer must stop this test (fatal assertion)
  // rather than be recorded and then crash the whole test binary.
  ASSERT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset.
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  ASSERT_NE(iter, nullptr);

  // Iterate the dataset and get each row.
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));
  MS_LOG(INFO) << "iter->GetNextRow(&row) OK";

  EXPECT_NE(row.find("waveform"), row.end());
  EXPECT_NE(row.find("sample_rate"), row.end());
  EXPECT_NE(row.find("transcription"), row.end());
  EXPECT_NE(row.find("normalized_transcription"), row.end());

  // An empty row marks the end of the data; count the rows seen before it.
  uint64_t i = 0;
  while (row.size() != 0) {
    i++;
    auto waveform = row["waveform"];
    MS_LOG(INFO) << "Tensor waveform shape: " << waveform.Shape();
    ASSERT_OK(iter->GetNextRow(&row));
  }

  EXPECT_EQ(i, 3);

  // Manually terminate the pipeline.
  iter->Stop();
}
/// Feature: LJSpeechDataset
/// Description: test LJSpeechDataset in pipeline mode
/// Expectation: the data is processed successfully
TEST_F(MindDataTestPipeline, TestLJSpeechDatasetWithPipeline) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLJSpeechDatasetWithPipeline.";

  // Create two LJSpeech Dataset.
  std::string folder_path = datasets_root_path_ + "/testLJSpeechData/";
  std::shared_ptr<Dataset> ds1 = LJSpeech(folder_path, std::make_shared<RandomSampler>(false, 3));
  std::shared_ptr<Dataset> ds2 = LJSpeech(folder_path, std::make_shared<RandomSampler>(false, 3));
  // Every pointer checked in this test is dereferenced by the next step, so the checks are
  // fatal: a null result ends this test instead of crashing the whole test binary.
  ASSERT_NE(ds1, nullptr);
  ASSERT_NE(ds2, nullptr);

  // Create two Repeat operation on ds.
  int32_t repeat_num = 1;
  ds1 = ds1->Repeat(repeat_num);
  ASSERT_NE(ds1, nullptr);
  repeat_num = 1;
  ds2 = ds2->Repeat(repeat_num);
  ASSERT_NE(ds2, nullptr);

  // Create two Project operation on ds.
  std::vector<std::string> column_project = {"waveform", "sample_rate", "transcription", "normalized_transcription"};
  ds1 = ds1->Project(column_project);
  ASSERT_NE(ds1, nullptr);
  ds2 = ds2->Project(column_project);
  ASSERT_NE(ds2, nullptr);

  // Create a Concat operation on the ds.
  ds1 = ds1->Concat({ds2});
  ASSERT_NE(ds1, nullptr);

  // Create an iterator over the result of the above dataset.
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds1->CreateIterator();
  ASSERT_NE(iter, nullptr);

  // Iterate the dataset and get each row.
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));

  EXPECT_NE(row.find("waveform"), row.end());
  EXPECT_NE(row.find("sample_rate"), row.end());
  EXPECT_NE(row.find("transcription"), row.end());
  EXPECT_NE(row.find("normalized_transcription"), row.end());

  // An empty row marks the end of the data; count the rows seen before it.
  uint64_t i = 0;
  while (row.size() != 0) {
    i++;
    auto waveform = row["waveform"];
    MS_LOG(INFO) << "Tensor waveform shape: " << waveform.Shape();
    ASSERT_OK(iter->GetNextRow(&row));
  }

  // Two datasets of 3 samples each were concatenated.
  EXPECT_EQ(i, 6);

  // Manually terminate the pipeline.
  iter->Stop();
}
/// Feature: LJSpeechDataset
/// Description: test getting size of LJSpeechDataset
/// Expectation: the size is correct
TEST_F(MindDataTestPipeline, TestLJSpeechGetDatasetSize) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLJSpeechGetDatasetSize.";

  // Create a LJSpeech Dataset.
  std::string folder_path = datasets_root_path_ + "/testLJSpeechData/";
  std::shared_ptr<Dataset> ds = LJSpeech(folder_path);
  // Fatal check: ds is dereferenced on the next line.
  ASSERT_NE(ds, nullptr);

  EXPECT_EQ(ds->GetDatasetSize(), 3);
}
/// Feature: LJSpeechDataset
/// Description: test LJSpeechDataset with mix getter
/// Expectation: the data is processed successfully
TEST_F(MindDataTestPipeline, TestLJSpeechGetters) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLJSpeechGetters.";

  // Create a LJSpeech Dataset.
  std::string folder_path = datasets_root_path_ + "/testLJSpeechData/";
  std::shared_ptr<Dataset> ds = LJSpeech(folder_path);
  // Fatal check: ds is dereferenced by every getter call below.
  ASSERT_NE(ds, nullptr);

  EXPECT_EQ(ds->GetDatasetSize(), 3);
  std::vector<DataType> types = ToDETypes(ds->GetOutputTypes());
  std::vector<TensorShape> shapes = ToTensorShapeVec(ds->GetOutputShapes());
  std::vector<std::string> column_names = {"waveform", "sample_rate", "transcription", "normalized_transcription"};

  // The size checks are fatal because the vectors are indexed right after them.
  ASSERT_EQ(types.size(), 4);
  EXPECT_EQ(types[0].ToString(), "float32");
  EXPECT_EQ(types[1].ToString(), "int32");
  EXPECT_EQ(types[2].ToString(), "string");
  EXPECT_EQ(types[3].ToString(), "string");

  ASSERT_EQ(shapes.size(), 4);
  EXPECT_EQ(shapes[1].ToString(), "<>");
  EXPECT_EQ(shapes[2].ToString(), "<>");
  EXPECT_EQ(shapes[3].ToString(), "<>");
  EXPECT_EQ(ds->GetBatchSize(), 1);
  EXPECT_EQ(ds->GetRepeatCount(), 1);

  // Call the getters a second time and expect the same answers as the first round.
  EXPECT_EQ(ds->GetDatasetSize(), 3);
  EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
  EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
  EXPECT_EQ(ds->GetColumnNames(), column_names);
}
/// Feature: LJSpeechDataset
/// Description: test LJSpeechDataset with the fail of reading dataset
/// Expectation: throw correct error and message
TEST_F(MindDataTestPipeline, TestLJSpeechDatasetFail) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLJSpeechDatasetFail.";

  // Create a LJSpeech Dataset.
  std::shared_ptr<Dataset> ds = LJSpeech("", std::make_shared<RandomSampler>(false, 3));
  // Fatal check: ds is dereferenced on the next statement.
  ASSERT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  // Expect failure: invalid LJSpeech input.
  EXPECT_EQ(iter, nullptr);
}
/// Feature: LJSpeechDataset
/// Description: test LJSpeechDataset with the null sampler
/// Expectation: throw correct error and message
TEST_F(MindDataTestPipeline, TestLJSpeechDatasetWithNullSamplerFail) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLJSpeechDatasetWithNullSamplerFail.";

  // Create a LJSpeech Dataset.
  std::string folder_path = datasets_root_path_ + "/testLJSpeechData/";
  std::shared_ptr<Dataset> ds = LJSpeech(folder_path, nullptr);
  // Fatal check: ds is dereferenced on the next statement.
  ASSERT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  // Expect failure: invalid LJSpeech input, sampler cannot be nullptr.
  EXPECT_EQ(iter, nullptr);
}

View File

@ -0,0 +1,3 @@
my_wave_1|this is my_wave_1|this is my_wave_1
my_wave_2|this is my_wave_2|this is my_wave_2
my_wave_3|this is my_wave_3|this is my_wave_3
1 my_wave_1 this is my_wave_1 this is my_wave_1
2 my_wave_2 this is my_wave_2 this is my_wave_2
3 my_wave_3 this is my_wave_3 this is my_wave_3

View File

@ -0,0 +1,3 @@
my_wave_1|this is my_wave_1|this is my_wave_1
my_wave_2|this is my_wave_2|this is my_wave_2
my_wave_3|this is my_wave_3|this is my_wave_3
1 my_wave_1 this is my_wave_1 this is my_wave_1
2 my_wave_2 this is my_wave_2 this is my_wave_2
3 my_wave_3 this is my_wave_3 this is my_wave_3

View File

@ -0,0 +1,161 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test LJSpeech dataset operators
"""
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.audio.transforms as audio
from mindspore import log as logger
DATA_DIR = "../data/dataset/testLJSpeechData/"
def test_lj_speech_basic():
    """
    Feature: LJSpeechDataset
    Description: read the dataset as a whole, with num_samples, and with repeat
    Expectation: the number of rows produced matches each configuration
    """
    logger.info("Test LJSpeechDataset Op")

    def count_rows(dataset):
        # One pass over the dataset, counting the rows it yields.
        return sum(1 for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True))

    # case 1: test loading whole dataset
    whole_dataset = ds.LJSpeechDataset(DATA_DIR)
    assert count_rows(whole_dataset) == 3

    # case 2: test num_samples
    limited_dataset = ds.LJSpeechDataset(DATA_DIR, num_samples=3)
    assert count_rows(limited_dataset) == 3

    # case 3: test repeat
    repeated_dataset = ds.LJSpeechDataset(DATA_DIR, num_samples=3).repeat(5)
    assert count_rows(repeated_dataset) == 15
def test_lj_speech_sequential_sampler():
    """
    Feature: LJSpeechDataset
    Description: compare a SequentialSampler pipeline with a shuffle=False pipeline
    Expectation: both pipelines yield the same sample rates and the requested number of rows
    """
    logger.info("Test LJSpeechDataset Op with SequentialSampler")
    expected_rows = 3
    sequential = ds.SequentialSampler(num_samples=expected_rows)
    sampled = ds.LJSpeechDataset(DATA_DIR, sampler=sequential)
    unshuffled = ds.LJSpeechDataset(DATA_DIR, shuffle=False, num_samples=expected_rows)

    # Walk both pipelines in lockstep and collect the sample rate of every row.
    rates_sampled, rates_unshuffled = [], []
    row_pairs = zip(sampled.create_dict_iterator(num_epochs=1, output_numpy=True),
                    unshuffled.create_dict_iterator(num_epochs=1, output_numpy=True))
    for row_sampled, row_unshuffled in row_pairs:
        rates_sampled.append(row_sampled["sample_rate"])
        rates_unshuffled.append(row_unshuffled["sample_rate"])

    np.testing.assert_array_equal(rates_sampled, rates_unshuffled)
    assert len(rates_sampled) == expected_rows
def test_lj_speech_exception():
    """
    Feature: LJSpeechDataset
    Description: construct LJSpeechDataset with invalid argument combinations and a failing map function
    Expectation: each case raises the expected exception type with the expected message
    """
    logger.info("Test error cases for LJSpeechDataset")

    # sampler cannot be combined with shuffle
    with pytest.raises(RuntimeError, match="sampler and shuffle cannot be specified at the same time"):
        ds.LJSpeechDataset(DATA_DIR, shuffle=False, sampler=ds.PKSampler(3))

    # sampler cannot be combined with sharding
    with pytest.raises(RuntimeError, match="sampler and sharding cannot be specified at the same time"):
        ds.LJSpeechDataset(DATA_DIR, sampler=ds.PKSampler(3), num_shards=2, shard_id=0)

    # num_shards and shard_id have to be given together
    with pytest.raises(RuntimeError, match="num_shards is specified and currently requires shard_id as well"):
        ds.LJSpeechDataset(DATA_DIR, num_shards=10)

    with pytest.raises(RuntimeError, match="shard_id is specified but num_shards is not"):
        ds.LJSpeechDataset(DATA_DIR, shard_id=0)

    # shard_id outside of [0, num_shards)
    for shard_total, shard_index in ((5, -1), (5, 5), (2, 5)):
        with pytest.raises(ValueError, match="Input shard_id is not within the required interval"):
            ds.LJSpeechDataset(DATA_DIR, num_shards=shard_total, shard_id=shard_index)

    # rejected worker counts
    for worker_count in (0, 256, -2):
        with pytest.raises(ValueError, match="num_parallel_workers exceeds"):
            ds.LJSpeechDataset(DATA_DIR, shuffle=False, num_parallel_workers=worker_count)

    # shard_id of the wrong type
    with pytest.raises(TypeError, match="Argument shard_id"):
        ds.LJSpeechDataset(DATA_DIR, num_shards=2, shard_id="0")

    def exception_func(item):
        raise Exception("Error occur!")

    # a map function that raises, applied to each of the two columns in turn
    for column in ("waveform", "sample_rate"):
        with pytest.raises(RuntimeError, match="The corresponding data files"):
            data = ds.LJSpeechDataset(DATA_DIR)
            data = data.map(operations=exception_func, input_columns=[column], num_parallel_workers=1)
            for _ in data.__iter__():
                pass
def test_lj_speech_pipeline():
    """
    Feature: LJSpeechDataset
    Description: map the BandBiquad audio operation over the waveform column
    Expectation: the mapped pipeline yields 3 rows
    """
    # Original waveform
    dataset = ds.LJSpeechDataset(DATA_DIR)
    band_biquad_op = audio.BandBiquad(8000, 200.0)

    # Filtered waveform by bandbiquad
    filtered = dataset.map(input_columns=["waveform"], operations=band_biquad_op, num_parallel_workers=2)

    row_count = sum(1 for _ in filtered.create_dict_iterator(num_epochs=1, output_numpy=True))
    assert row_count == 3
if __name__ == '__main__':
    # Run the tests of this module one after another when the file is executed as a script.
    for test_case in (test_lj_speech_basic,
                      test_lj_speech_sequential_sampler,
                      test_lj_speech_exception,
                      test_lj_speech_pipeline):
        test_case()