update review problems

This commit is contained in:
Xiao Tianci 2022-01-11 15:31:11 +08:00
parent 1ac5261dfe
commit 96a2e38d01
69 changed files with 1040 additions and 1039 deletions

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -92,9 +92,10 @@ std::shared_ptr<TensorOperation> AmplitudeToDB::Parse() {
} }
// Angle Transform Operation. // Angle Transform Operation.
Angle::Angle() {} Angle::Angle() = default;
std::shared_ptr<TensorOperation> Angle::Parse() { return std::make_shared<AngleOperation>(); } std::shared_ptr<TensorOperation> Angle::Parse() { return std::make_shared<AngleOperation>(); }
// BandBiquad Transform Operation. // BandBiquad Transform Operation.
struct BandBiquad::Data { struct BandBiquad::Data {
Data(int32_t sample_rate, float central_freq, float Q, bool noise) Data(int32_t sample_rate, float central_freq, float Q, bool noise)
@ -225,7 +226,7 @@ struct DBToAmplitude::Data {
float power_; float power_;
}; };
DBToAmplitude::DBToAmplitude(float ref, float power) : data_(std::make_shared<Data>(power, power)) {} DBToAmplitude::DBToAmplitude(float ref, float power) : data_(std::make_shared<Data>(ref, power)) {}
std::shared_ptr<TensorOperation> DBToAmplitude::Parse() { std::shared_ptr<TensorOperation> DBToAmplitude::Parse() {
return std::make_shared<DBToAmplitudeOperation>(data_->ref_, data_->power_); return std::make_shared<DBToAmplitudeOperation>(data_->ref_, data_->power_);
@ -431,7 +432,7 @@ struct LFilter::Data {
bool clamp_; bool clamp_;
}; };
LFilter::LFilter(std::vector<float> a_coeffs, std::vector<float> b_coeffs, bool clamp) LFilter::LFilter(const std::vector<float> &a_coeffs, const std::vector<float> &b_coeffs, bool clamp)
: data_(std::make_shared<Data>(a_coeffs, b_coeffs, clamp)) {} : data_(std::make_shared<Data>(a_coeffs, b_coeffs, clamp)) {}
std::shared_ptr<TensorOperation> LFilter::Parse() { std::shared_ptr<TensorOperation> LFilter::Parse() {
@ -585,6 +586,17 @@ struct Spectrogram::Data {
bool onesided_; bool onesided_;
}; };
Spectrogram::Spectrogram(int32_t n_fft, int32_t win_length, int32_t hop_length, int32_t pad, WindowType window,
float power, bool normalized, bool center, BorderType pad_mode, bool onesided)
: data_(std::make_shared<Data>(n_fft, win_length, hop_length, pad, window, power, normalized, center, pad_mode,
onesided)) {}
std::shared_ptr<TensorOperation> Spectrogram::Parse() {
return std::make_shared<SpectrogramOperation>(data_->n_fft_, data_->win_length_, data_->hop_length_, data_->pad_,
data_->window_, data_->power_, data_->normalized_, data_->center_,
data_->pad_mode_, data_->onesided_);
}
// SpectralCentroid Transform Operation. // SpectralCentroid Transform Operation.
struct SpectralCentroid::Data { struct SpectralCentroid::Data {
Data(int32_t sample_rate, int32_t n_fft, int32_t win_length, int32_t hop_length, int32_t pad, WindowType window) Data(int32_t sample_rate, int32_t n_fft, int32_t win_length, int32_t hop_length, int32_t pad, WindowType window)
@ -611,17 +623,6 @@ std::shared_ptr<TensorOperation> SpectralCentroid::Parse() {
data_->hop_length_, data_->pad_, data_->window_); data_->hop_length_, data_->pad_, data_->window_);
} }
Spectrogram::Spectrogram(int32_t n_fft, int32_t win_length, int32_t hop_length, int32_t pad, WindowType window,
float power, bool normalized, bool center, BorderType pad_mode, bool onesided)
: data_(std::make_shared<Data>(n_fft, win_length, hop_length, pad, window, power, normalized, center, pad_mode,
onesided)) {}
std::shared_ptr<TensorOperation> Spectrogram::Parse() {
return std::make_shared<SpectrogramOperation>(data_->n_fft_, data_->win_length_, data_->hop_length_, data_->pad_,
data_->window_, data_->power_, data_->normalized_, data_->center_,
data_->pad_mode_, data_->onesided_);
}
// TimeMasking Transform Operation. // TimeMasking Transform Operation.
struct TimeMasking::Data { struct TimeMasking::Data {
Data(bool iid_masks, int32_t time_mask_param, int32_t mask_start, float mask_value) Data(bool iid_masks, int32_t time_mask_param, int32_t mask_start, float mask_value)

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -13,19 +13,17 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "minddata/dataset/include/dataset/config.h"
#include "minddata/dataset/core/config_manager.h" #include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/global_context.h" #include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/include/dataset/config.h"
#include "minddata/dataset/util/log_adapter.h" #include "minddata/dataset/util/log_adapter.h"
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
// Config operations for setting and getting the configuration. // Config operations for setting and getting the configuration.
namespace config { namespace config {
std::shared_ptr<ConfigManager> _config = GlobalContext::config_manager(); std::shared_ptr<ConfigManager> _config = GlobalContext::config_manager();
// Function to set the seed to be used in any random generator // Function to set the seed to be used in any random generator
@ -102,7 +100,6 @@ bool load(const std::vector<char> &file) {
} }
return true; return true;
} }
} // namespace config } // namespace config
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020-2021 Huawei Technologies Co., Ltd * Copyright 2020-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -121,41 +121,49 @@ Status DataHelper::UpdateValueIF(const std::vector<char> &in_file, const std::ve
auto jh = JsonHelper(); auto jh = JsonHelper();
return jh.UpdateValue(CharToString(in_file), CharToString(key), value, CharToString(out_file)); return jh.UpdateValue(CharToString(in_file), CharToString(key), value, CharToString(out_file));
} }
Status DataHelper::UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const uint8_t &value, Status DataHelper::UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const uint8_t &value,
const std::vector<char> &out_file) { const std::vector<char> &out_file) {
auto jh = JsonHelper(); auto jh = JsonHelper();
return jh.UpdateValue(CharToString(in_file), CharToString(key), value, CharToString(out_file)); return jh.UpdateValue(CharToString(in_file), CharToString(key), value, CharToString(out_file));
} }
Status DataHelper::UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const int16_t &value, Status DataHelper::UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const int16_t &value,
const std::vector<char> &out_file) { const std::vector<char> &out_file) {
auto jh = JsonHelper(); auto jh = JsonHelper();
return jh.UpdateValue(CharToString(in_file), CharToString(key), value, CharToString(out_file)); return jh.UpdateValue(CharToString(in_file), CharToString(key), value, CharToString(out_file));
} }
Status DataHelper::UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const uint16_t &value, Status DataHelper::UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const uint16_t &value,
const std::vector<char> &out_file) { const std::vector<char> &out_file) {
auto jh = JsonHelper(); auto jh = JsonHelper();
return jh.UpdateValue(CharToString(in_file), CharToString(key), value, CharToString(out_file)); return jh.UpdateValue(CharToString(in_file), CharToString(key), value, CharToString(out_file));
} }
Status DataHelper::UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const int32_t &value, Status DataHelper::UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const int32_t &value,
const std::vector<char> &out_file) { const std::vector<char> &out_file) {
auto jh = JsonHelper(); auto jh = JsonHelper();
return jh.UpdateValue(CharToString(in_file), CharToString(key), value, CharToString(out_file)); return jh.UpdateValue(CharToString(in_file), CharToString(key), value, CharToString(out_file));
} }
Status DataHelper::UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const uint32_t &value, Status DataHelper::UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const uint32_t &value,
const std::vector<char> &out_file) { const std::vector<char> &out_file) {
auto jh = JsonHelper(); auto jh = JsonHelper();
return jh.UpdateValue(CharToString(in_file), CharToString(key), value, CharToString(out_file)); return jh.UpdateValue(CharToString(in_file), CharToString(key), value, CharToString(out_file));
} }
Status DataHelper::UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const int64_t &value, Status DataHelper::UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const int64_t &value,
const std::vector<char> &out_file) { const std::vector<char> &out_file) {
auto jh = JsonHelper(); auto jh = JsonHelper();
return jh.UpdateValue(CharToString(in_file), CharToString(key), value, CharToString(out_file)); return jh.UpdateValue(CharToString(in_file), CharToString(key), value, CharToString(out_file));
} }
Status DataHelper::UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const uint64_t &value, Status DataHelper::UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const uint64_t &value,
const std::vector<char> &out_file) { const std::vector<char> &out_file) {
auto jh = JsonHelper(); auto jh = JsonHelper();
return jh.UpdateValue(CharToString(in_file), CharToString(key), value, CharToString(out_file)); return jh.UpdateValue(CharToString(in_file), CharToString(key), value, CharToString(out_file));
} }
Status DataHelper::UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const float &value, Status DataHelper::UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const float &value,
const std::vector<char> &out_file) { const std::vector<char> &out_file) {
auto jh = JsonHelper(); auto jh = JsonHelper();
@ -179,6 +187,5 @@ size_t DataHelper::DumpData(const unsigned char *tensor_addr, const size_t &tens
auto jh = JsonHelper(); auto jh = JsonHelper();
return jh.DumpData(tensor_addr, tensor_size, addr, buffer_size); return jh.DumpData(tensor_addr, tensor_size, addr, buffer_size);
} }
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020-2021 Huawei Technologies Co., Ltd * Copyright 2020-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -15,33 +15,29 @@
*/ */
#include "minddata/dataset/include/dataset/datasets.h" #include "minddata/dataset/include/dataset/datasets.h"
#include <algorithm> #include <algorithm>
#include <fstream> #include <fstream>
#include <unordered_set> #include <unordered_set>
#include <utility> #include <utility>
#include <nlohmann/json.hpp> #include <nlohmann/json.hpp>
#include "minddata/dataset/core/tensor.h" #include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/core/type_id.h"
#include "minddata/dataset/engine/consumers/pull_based_tree_consumer.h"
#include "minddata/dataset/engine/consumers/tree_consumer.h"
#include "minddata/dataset/engine/runtime_context.h" #include "minddata/dataset/engine/runtime_context.h"
#include "minddata/dataset/include/dataset/constants.h" #include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/include/dataset/iterator.h" #include "minddata/dataset/include/dataset/iterator.h"
#include "minddata/dataset/include/dataset/samplers.h" #include "minddata/dataset/include/dataset/samplers.h"
#include "minddata/dataset/include/dataset/transforms.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/core/client.h"
#include "minddata/dataset/core/type_id.h"
#include "minddata/dataset/engine/consumers/tree_consumer.h"
#include "minddata/dataset/engine/consumers/pull_based_tree_consumer.h"
#include "minddata/dataset/kernels/c_func_op.h" #include "minddata/dataset/kernels/c_func_op.h"
#include "minddata/dataset/kernels/tensor_op.h" #include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/status.h"
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h" #include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h"
#endif
#ifndef ENABLE_ANDROID
#include "minddata/dataset/text/sentence_piece_vocab.h" #include "minddata/dataset/text/sentence_piece_vocab.h"
#include "minddata/dataset/text/vocab.h" #include "minddata/dataset/text/vocab.h"
#endif #endif
@ -49,6 +45,7 @@
// Sampler headers (in alphabetical order) // Sampler headers (in alphabetical order)
#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h" #include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"
// IR dataset node
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" #include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
// IR non-leaf nodes // IR non-leaf nodes
@ -60,17 +57,13 @@
#include "minddata/dataset/engine/ir/datasetops/concat_node.h" #include "minddata/dataset/engine/ir/datasetops/concat_node.h"
#include "minddata/dataset/engine/ir/datasetops/filter_node.h" #include "minddata/dataset/engine/ir/datasetops/filter_node.h"
#endif #endif
#include "minddata/dataset/engine/ir/datasetops/map_node.h" #include "minddata/dataset/engine/ir/datasetops/map_node.h"
#include "minddata/dataset/engine/ir/datasetops/project_node.h" #include "minddata/dataset/engine/ir/datasetops/project_node.h"
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/rename_node.h" #include "minddata/dataset/engine/ir/datasetops/rename_node.h"
#endif #endif
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h" #include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/skip_node.h" #include "minddata/dataset/engine/ir/datasetops/skip_node.h"
#include "minddata/dataset/engine/ir/datasetops/take_node.h" #include "minddata/dataset/engine/ir/datasetops/take_node.h"
@ -78,22 +71,15 @@
#include "minddata/dataset/engine/ir/datasetops/zip_node.h" #include "minddata/dataset/engine/ir/datasetops/zip_node.h"
#endif #endif
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/services.h"
// IR leaf nodes // IR leaf nodes
#include "minddata/dataset/engine/ir/datasetops/source/ag_news_node.h" #include "minddata/dataset/engine/ir/datasetops/source/ag_news_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/album_node.h" #include "minddata/dataset/engine/ir/datasetops/source/album_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
// IR leaf nodes disabled for android
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/source/amazon_review_node.h" #include "minddata/dataset/engine/ir/datasetops/source/amazon_review_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/caltech256_node.h" #include "minddata/dataset/engine/ir/datasetops/source/caltech256_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h" #include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h" #include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/cityscapes_node.h" #include "minddata/dataset/engine/ir/datasetops/source/cityscapes_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/clue_node.h" #include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/coco_node.h" #include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
@ -114,6 +100,9 @@
#include "minddata/dataset/engine/ir/datasetops/source/lj_speech_node.h" #include "minddata/dataset/engine/ir/datasetops/source/lj_speech_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h" #include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h" #include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
#endif
#include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/source/penn_treebank_node.h" #include "minddata/dataset/engine/ir/datasetops/source/penn_treebank_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/photo_tour_node.h" #include "minddata/dataset/engine/ir/datasetops/source/photo_tour_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/places365_node.h" #include "minddata/dataset/engine/ir/datasetops/source/places365_node.h"
@ -139,7 +128,6 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
// convert MSTensorVec to DE TensorRow, return empty if fails // convert MSTensorVec to DE TensorRow, return empty if fails
TensorRow VecToRow(const MSTensorVec &v) { TensorRow VecToRow(const MSTensorVec &v) {
TensorRow row; TensorRow row;
@ -160,19 +148,20 @@ TensorRow VecToRow(const MSTensorVec &v) {
MSTensorVec RowToVec(const TensorRow &v) { MSTensorVec RowToVec(const TensorRow &v) {
MSTensorVec rv; MSTensorVec rv;
rv.reserve(v.size()); rv.reserve(v.size());
std::transform(v.begin(), v.end(), std::back_inserter(rv), [](std::shared_ptr<Tensor> t) -> MSTensor { std::transform(v.begin(), v.end(), std::back_inserter(rv), [](const std::shared_ptr<Tensor> &t) -> MSTensor {
return mindspore::MSTensor(std::make_shared<DETensor>(t)); return mindspore::MSTensor(std::make_shared<DETensor>(t));
}); });
return rv; return rv;
} }
// Convert a std::function<TensorRow(TensorRow)> to std::function<MSTensorVec(MSTensor)> with this helper // Convert a std::function<TensorRow(TensorRow)> to std::function<MSTensorVec(MSTensor)> with this helper
TensorRow FuncPtrConverter(std::function<MSTensorVec(MSTensorVec)> func, TensorRow in_row) { TensorRow FuncPtrConverter(const std::function<MSTensorVec(MSTensorVec)> &func, const TensorRow &in_row) {
return VecToRow(func(RowToVec(in_row))); return VecToRow(func(RowToVec(in_row)));
} }
// Function to create the iterator, which will build and launch the execution tree. // Function to create the iterator, which will build and launch the execution tree.
std::shared_ptr<Iterator> Dataset::CreateIteratorCharIF(std::vector<std::vector<char>> columns, int32_t num_epochs) { std::shared_ptr<Iterator> Dataset::CreateIteratorCharIF(const std::vector<std::vector<char>> &columns,
int32_t num_epochs) {
std::shared_ptr<Iterator> iter; std::shared_ptr<Iterator> iter;
try { try {
auto ds = shared_from_this(); auto ds = shared_from_this();
@ -198,7 +187,7 @@ std::shared_ptr<Iterator> Dataset::CreateIteratorCharIF(std::vector<std::vector<
} }
// Function to create the iterator, which will build and launch the execution tree. // Function to create the iterator, which will build and launch the execution tree.
std::shared_ptr<PullIterator> Dataset::CreatePullBasedIterator(std::vector<std::vector<char>> columns) { std::shared_ptr<PullIterator> Dataset::CreatePullBasedIterator(const std::vector<std::vector<char>> &columns) {
// The specified columns will be selected from the dataset and passed down the pipeline // The specified columns will be selected from the dataset and passed down the pipeline
// in the order specified, other columns will be discarded. // in the order specified, other columns will be discarded.
// This code is not in a try/catch block because there is no execution tree class that will be created. // This code is not in a try/catch block because there is no execution tree class that will be created.
@ -424,7 +413,7 @@ std::shared_ptr<SchemaObj> SchemaCharIF(const std::vector<char> &schema_file) {
// (In alphabetical order) // (In alphabetical order)
// Function to create a Batch dataset // Function to create a Batch dataset
BatchDataset::BatchDataset(std::shared_ptr<Dataset> input, int32_t batch_size, bool drop_remainder) { BatchDataset::BatchDataset(const std::shared_ptr<Dataset> &input, int32_t batch_size, bool drop_remainder) {
// Default values // Default values
auto ds = std::make_shared<BatchNode>(input->IRNode(), batch_size, drop_remainder); auto ds = std::make_shared<BatchNode>(input->IRNode(), batch_size, drop_remainder);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
@ -433,9 +422,9 @@ BatchDataset::BatchDataset(std::shared_ptr<Dataset> input, int32_t batch_size, b
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
// Function to create a BucketBatchByLength dataset // Function to create a BucketBatchByLength dataset
BucketBatchByLengthDataset::BucketBatchByLengthDataset( BucketBatchByLengthDataset::BucketBatchByLengthDataset(
std::shared_ptr<Dataset> input, const std::vector<std::vector<char>> &column_names, const std::shared_ptr<Dataset> &input, const std::vector<std::vector<char>> &column_names,
const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes, const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
std::function<MSTensorVec(MSTensorVec)> element_length_function, const std::function<MSTensorVec(MSTensorVec)> &element_length_function,
const std::map<std::vector<char>, std::pair<std::vector<int64_t>, MSTensor>> &pad_info, bool pad_to_bucket_boundary, const std::map<std::vector<char>, std::pair<std::vector<int64_t>, MSTensor>> &pad_info, bool pad_to_bucket_boundary,
bool drop_remainder) { bool drop_remainder) {
std::shared_ptr<TensorOp> c_func = nullptr; std::shared_ptr<TensorOp> c_func = nullptr;
@ -467,7 +456,7 @@ BucketBatchByLengthDataset::BucketBatchByLengthDataset(
ConcatDataset::ConcatDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) { ConcatDataset::ConcatDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) {
std::vector<std::shared_ptr<DatasetNode>> all_datasets; std::vector<std::shared_ptr<DatasetNode>> all_datasets;
(void)std::transform(datasets.begin(), datasets.end(), std::back_inserter(all_datasets), (void)std::transform(datasets.begin(), datasets.end(), std::back_inserter(all_datasets),
[](std::shared_ptr<Dataset> dataset) -> std::shared_ptr<DatasetNode> { [](const std::shared_ptr<Dataset> &dataset) -> std::shared_ptr<DatasetNode> {
return (dataset != nullptr) ? dataset->IRNode() : nullptr; return (dataset != nullptr) ? dataset->IRNode() : nullptr;
}); });
@ -476,7 +465,8 @@ ConcatDataset::ConcatDataset(const std::vector<std::shared_ptr<Dataset>> &datase
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }
FilterDataset::FilterDataset(std::shared_ptr<Dataset> input, std::function<MSTensorVec(MSTensorVec)> predicate, FilterDataset::FilterDataset(const std::shared_ptr<Dataset> &input,
const std::function<MSTensorVec(MSTensorVec)> &predicate,
const std::vector<std::vector<char>> &input_columns) { const std::vector<std::vector<char>> &input_columns) {
std::shared_ptr<TensorOp> c_func = nullptr; std::shared_ptr<TensorOp> c_func = nullptr;
if (predicate) { if (predicate) {
@ -488,11 +478,13 @@ FilterDataset::FilterDataset(std::shared_ptr<Dataset> input, std::function<MSTen
} }
#endif #endif
MapDataset::MapDataset(std::shared_ptr<Dataset> input, std::vector<std::shared_ptr<TensorOperation>> operations, MapDataset::MapDataset(const std::shared_ptr<Dataset> &input,
const std::vector<std::shared_ptr<TensorOperation>> &operations,
const std::vector<std::vector<char>> &input_columns, const std::vector<std::vector<char>> &input_columns,
const std::vector<std::vector<char>> &output_columns, const std::vector<std::vector<char>> &output_columns,
const std::vector<std::vector<char>> &project_columns, const std::vector<std::vector<char>> &project_columns,
const std::shared_ptr<DatasetCache> &cache, std::vector<std::shared_ptr<DSCallback>> callbacks) { const std::shared_ptr<DatasetCache> &cache,
const std::vector<std::shared_ptr<DSCallback>> &callbacks) {
auto ds = std::make_shared<MapNode>(input->IRNode(), operations, VectorCharToString(input_columns), auto ds = std::make_shared<MapNode>(input->IRNode(), operations, VectorCharToString(input_columns),
VectorCharToString(output_columns), VectorCharToString(project_columns), cache, VectorCharToString(output_columns), VectorCharToString(project_columns), cache,
callbacks); callbacks);
@ -500,13 +492,14 @@ MapDataset::MapDataset(std::shared_ptr<Dataset> input, std::vector<std::shared_p
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }
ProjectDataset::ProjectDataset(std::shared_ptr<Dataset> input, const std::vector<std::vector<char>> &columns) { ProjectDataset::ProjectDataset(const std::shared_ptr<Dataset> &input, const std::vector<std::vector<char>> &columns) {
auto ds = std::make_shared<ProjectNode>(input->IRNode(), VectorCharToString(columns)); auto ds = std::make_shared<ProjectNode>(input->IRNode(), VectorCharToString(columns));
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
RenameDataset::RenameDataset(std::shared_ptr<Dataset> input, const std::vector<std::vector<char>> &input_columns, RenameDataset::RenameDataset(const std::shared_ptr<Dataset> &input, const std::vector<std::vector<char>> &input_columns,
const std::vector<std::vector<char>> &output_columns) { const std::vector<std::vector<char>> &output_columns) {
auto ds = std::make_shared<RenameNode>(input->IRNode(), VectorCharToString(input_columns), auto ds = std::make_shared<RenameNode>(input->IRNode(), VectorCharToString(input_columns),
VectorCharToString(output_columns)); VectorCharToString(output_columns));
@ -515,13 +508,13 @@ RenameDataset::RenameDataset(std::shared_ptr<Dataset> input, const std::vector<s
} }
#endif #endif
RepeatDataset::RepeatDataset(std::shared_ptr<Dataset> input, int32_t count) { RepeatDataset::RepeatDataset(const std::shared_ptr<Dataset> &input, int32_t count) {
auto ds = std::make_shared<RepeatNode>(input->IRNode(), count); auto ds = std::make_shared<RepeatNode>(input->IRNode(), count);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }
ShuffleDataset::ShuffleDataset(std::shared_ptr<Dataset> input, int32_t buffer_size) { ShuffleDataset::ShuffleDataset(const std::shared_ptr<Dataset> &input, int32_t buffer_size) {
// Pass in reshuffle_each_epoch with true // Pass in reshuffle_each_epoch with true
auto ds = std::make_shared<ShuffleNode>(input->IRNode(), buffer_size, true); auto ds = std::make_shared<ShuffleNode>(input->IRNode(), buffer_size, true);
@ -529,13 +522,13 @@ ShuffleDataset::ShuffleDataset(std::shared_ptr<Dataset> input, int32_t buffer_si
} }
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
SkipDataset::SkipDataset(std::shared_ptr<Dataset> input, int32_t count) { SkipDataset::SkipDataset(const std::shared_ptr<Dataset> &input, int32_t count) {
auto ds = std::make_shared<SkipNode>(input->IRNode(), count); auto ds = std::make_shared<SkipNode>(input->IRNode(), count);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }
TakeDataset::TakeDataset(std::shared_ptr<Dataset> input, int32_t count) { TakeDataset::TakeDataset(const std::shared_ptr<Dataset> &input, int32_t count) {
auto ds = std::make_shared<TakeNode>(input->IRNode(), count); auto ds = std::make_shared<TakeNode>(input->IRNode(), count);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
@ -544,7 +537,7 @@ TakeDataset::TakeDataset(std::shared_ptr<Dataset> input, int32_t count) {
ZipDataset::ZipDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) { ZipDataset::ZipDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) {
std::vector<std::shared_ptr<DatasetNode>> all_datasets; std::vector<std::shared_ptr<DatasetNode>> all_datasets;
(void)std::transform(datasets.begin(), datasets.end(), std::back_inserter(all_datasets), (void)std::transform(datasets.begin(), datasets.end(), std::back_inserter(all_datasets),
[](std::shared_ptr<Dataset> dataset) -> std::shared_ptr<DatasetNode> { [](const std::shared_ptr<Dataset> &dataset) -> std::shared_ptr<DatasetNode> {
return (dataset != nullptr) ? dataset->IRNode() : nullptr; return (dataset != nullptr) ? dataset->IRNode() : nullptr;
}); });
auto ds = std::make_shared<ZipNode>(all_datasets); auto ds = std::make_shared<ZipNode>(all_datasets);
@ -552,6 +545,7 @@ ZipDataset::ZipDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) {
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }
#endif #endif
int64_t Dataset::GetBatchSize() { int64_t Dataset::GetBatchSize() {
int64_t batch_size = -1; int64_t batch_size = -1;
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
@ -737,10 +731,11 @@ Status SchemaObj::add_column_char(const std::vector<char> &name, const std::vect
} }
Status SchemaObj::schema_to_json(nlohmann::json *out_json) { Status SchemaObj::schema_to_json(nlohmann::json *out_json) {
RETURN_UNEXPECTED_IF_NULL(out_json);
nlohmann::json json_file; nlohmann::json json_file;
json_file["columns"] = data_->columns_; json_file["columns"] = data_->columns_;
std::string str_dataset_type_(data_->dataset_type_); std::string str_dataset_type_(data_->dataset_type_);
if (str_dataset_type_ != "") { if (!str_dataset_type_.empty()) {
json_file["datasetType"] = str_dataset_type_; json_file["datasetType"] = str_dataset_type_;
} }
@ -751,13 +746,13 @@ Status SchemaObj::schema_to_json(nlohmann::json *out_json) {
return Status::OK(); return Status::OK();
} }
const std::vector<char> SchemaObj::to_json_char() { std::vector<char> SchemaObj::to_json_char() {
nlohmann::json json_file; nlohmann::json json_file;
this->schema_to_json(&json_file); this->schema_to_json(&json_file);
return StringToChar(json_file.dump(2)); return StringToChar(json_file.dump(2));
} }
void SchemaObj::set_dataset_type(std::string dataset_type) { data_->dataset_type_ = dataset_type.data(); } void SchemaObj::set_dataset_type(const std::string &dataset_type) { data_->dataset_type_ = dataset_type; }
void SchemaObj::set_num_rows(int32_t num_rows) { data_->num_rows_ = num_rows; } void SchemaObj::set_num_rows(int32_t num_rows) { data_->num_rows_ = num_rows; }
@ -817,7 +812,7 @@ Status SchemaObj::from_json(nlohmann::json json_obj) {
for (const auto &it_child : json_obj.items()) { for (const auto &it_child : json_obj.items()) {
if (it_child.key() == "datasetType") { if (it_child.key() == "datasetType") {
std::string str_dataset_type_ = it_child.value(); std::string str_dataset_type_ = it_child.value();
data_->dataset_type_ = str_dataset_type_.data(); data_->dataset_type_ = str_dataset_type_;
} else if (it_child.key() == "numRows") { } else if (it_child.key() == "numRows") {
data_->num_rows_ = it_child.value(); data_->num_rows_ = it_child.value();
} else if (it_child.key() == "columns") { } else if (it_child.key() == "columns") {
@ -867,10 +862,10 @@ Status SchemaObj::ParseColumnStringCharIF(const std::vector<char> &json_string)
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
std::shared_ptr<DatasetCache> CreateDatasetCacheCharIF(session_id_type id, uint64_t mem_sz, bool spill, std::shared_ptr<DatasetCache> CreateDatasetCacheCharIF(session_id_type id, uint64_t mem_sz, bool spill,
std::optional<std::vector<char>> hostname, const std::optional<std::vector<char>> &hostname,
std::optional<int32_t> port, const std::optional<int32_t> &port,
std::optional<int32_t> num_connections, const std::optional<int32_t> &num_connections,
std::optional<int32_t> prefetch_sz) { const std::optional<int32_t> &prefetch_sz) {
auto cache = std::make_shared<DatasetCacheImpl>(id, mem_sz, spill, hostname, port, num_connections, prefetch_sz); auto cache = std::make_shared<DatasetCacheImpl>(id, mem_sz, spill, hostname, port, num_connections, prefetch_sz);
return cache; return cache;
} }
@ -901,9 +896,10 @@ AlbumDataset::AlbumDataset(const std::vector<char> &dataset_dir, const std::vect
VectorCharToString(column_names), decode, sampler_obj, cache); VectorCharToString(column_names), decode, sampler_obj, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }
AlbumDataset::AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema, AlbumDataset::AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
const std::vector<std::vector<char>> &column_names, bool decode, const std::vector<std::vector<char>> &column_names, bool decode,
const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache) { const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler.get().Parse(); auto sampler_obj = sampler.get().Parse();
auto ds = std::make_shared<AlbumNode>(CharToString(dataset_dir), CharToString(data_schema), auto ds = std::make_shared<AlbumNode>(CharToString(dataset_dir), CharToString(data_schema),
VectorCharToString(column_names), decode, sampler_obj, cache); VectorCharToString(column_names), decode, sampler_obj, cache);
@ -935,7 +931,7 @@ Caltech256Dataset::Caltech256Dataset(const std::vector<char> &dataset_dir, bool
} }
Caltech256Dataset::Caltech256Dataset(const std::vector<char> &dataset_dir, bool decode, Caltech256Dataset::Caltech256Dataset(const std::vector<char> &dataset_dir, bool decode,
const std::reference_wrapper<Sampler> sampler, const std::reference_wrapper<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache) { const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler.get().Parse(); auto sampler_obj = sampler.get().Parse();
auto ds = std::make_shared<Caltech256Node>(CharToString(dataset_dir), decode, sampler_obj, cache); auto ds = std::make_shared<Caltech256Node>(CharToString(dataset_dir), decode, sampler_obj, cache);
@ -951,6 +947,7 @@ CelebADataset::CelebADataset(const std::vector<char> &dataset_dir, const std::ve
SetCharToString(extensions), cache); SetCharToString(extensions), cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }
CelebADataset::CelebADataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, CelebADataset::CelebADataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const Sampler *sampler, bool decode, const std::set<std::vector<char>> &extensions, const Sampler *sampler, bool decode, const std::set<std::vector<char>> &extensions,
const std::shared_ptr<DatasetCache> &cache) { const std::shared_ptr<DatasetCache> &cache) {
@ -959,8 +956,9 @@ CelebADataset::CelebADataset(const std::vector<char> &dataset_dir, const std::ve
SetCharToString(extensions), cache); SetCharToString(extensions), cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }
CelebADataset::CelebADataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, CelebADataset::CelebADataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const std::reference_wrapper<Sampler> sampler, bool decode, const std::reference_wrapper<Sampler> &sampler, bool decode,
const std::set<std::vector<char>> &extensions, const std::set<std::vector<char>> &extensions,
const std::shared_ptr<DatasetCache> &cache) { const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler.get().Parse(); auto sampler_obj = sampler.get().Parse();
@ -975,14 +973,16 @@ Cifar10Dataset::Cifar10Dataset(const std::vector<char> &dataset_dir, const std::
auto ds = std::make_shared<Cifar10Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache); auto ds = std::make_shared<Cifar10Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }
Cifar10Dataset::Cifar10Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, Cifar10Dataset::Cifar10Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) { const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler ? sampler->Parse() : nullptr; auto sampler_obj = sampler ? sampler->Parse() : nullptr;
auto ds = std::make_shared<Cifar10Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache); auto ds = std::make_shared<Cifar10Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }
Cifar10Dataset::Cifar10Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, Cifar10Dataset::Cifar10Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const std::reference_wrapper<Sampler> sampler, const std::reference_wrapper<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache) { const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler.get().Parse(); auto sampler_obj = sampler.get().Parse();
auto ds = std::make_shared<Cifar10Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache); auto ds = std::make_shared<Cifar10Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
@ -995,14 +995,16 @@ Cifar100Dataset::Cifar100Dataset(const std::vector<char> &dataset_dir, const std
auto ds = std::make_shared<Cifar100Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache); auto ds = std::make_shared<Cifar100Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }
Cifar100Dataset::Cifar100Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, Cifar100Dataset::Cifar100Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) { const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler ? sampler->Parse() : nullptr; auto sampler_obj = sampler ? sampler->Parse() : nullptr;
auto ds = std::make_shared<Cifar100Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache); auto ds = std::make_shared<Cifar100Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }
Cifar100Dataset::Cifar100Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, Cifar100Dataset::Cifar100Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const std::reference_wrapper<Sampler> sampler, const std::reference_wrapper<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache) { const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler.get().Parse(); auto sampler_obj = sampler.get().Parse();
auto ds = std::make_shared<Cifar100Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache); auto ds = std::make_shared<Cifar100Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
@ -1030,7 +1032,7 @@ CityscapesDataset::CityscapesDataset(const std::vector<char> &dataset_dir, const
CityscapesDataset::CityscapesDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, CityscapesDataset::CityscapesDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const std::vector<char> &quality_mode, const std::vector<char> &task, bool decode, const std::vector<char> &quality_mode, const std::vector<char> &task, bool decode,
const std::reference_wrapper<Sampler> sampler, const std::reference_wrapper<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache) { const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler.get().Parse(); auto sampler_obj = sampler.get().Parse();
auto ds = std::make_shared<CityscapesNode>(CharToString(dataset_dir), CharToString(usage), CharToString(quality_mode), auto ds = std::make_shared<CityscapesNode>(CharToString(dataset_dir), CharToString(usage), CharToString(quality_mode),
@ -1054,6 +1056,7 @@ CocoDataset::CocoDataset(const std::vector<char> &dataset_dir, const std::vector
decode, sampler_obj, cache, extra_metadata); decode, sampler_obj, cache, extra_metadata);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }
CocoDataset::CocoDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file, CocoDataset::CocoDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
const std::vector<char> &task, const bool &decode, const Sampler *sampler, const std::vector<char> &task, const bool &decode, const Sampler *sampler,
const std::shared_ptr<DatasetCache> &cache, const bool &extra_metadata) { const std::shared_ptr<DatasetCache> &cache, const bool &extra_metadata) {
@ -1062,9 +1065,10 @@ CocoDataset::CocoDataset(const std::vector<char> &dataset_dir, const std::vector
decode, sampler_obj, cache, extra_metadata); decode, sampler_obj, cache, extra_metadata);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }
CocoDataset::CocoDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file, CocoDataset::CocoDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
const std::vector<char> &task, const bool &decode, const std::vector<char> &task, const bool &decode,
const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache, const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache,
const bool &extra_metadata) { const bool &extra_metadata) {
auto sampler_obj = sampler.get().Parse(); auto sampler_obj = sampler.get().Parse();
auto ds = std::make_shared<CocoNode>(CharToString(dataset_dir), CharToString(annotation_file), CharToString(task), auto ds = std::make_shared<CocoNode>(CharToString(dataset_dir), CharToString(annotation_file), CharToString(task),
@ -1118,7 +1122,7 @@ DIV2KDataset::DIV2KDataset(const std::vector<char> &dataset_dir, const std::vect
DIV2KDataset::DIV2KDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, DIV2KDataset::DIV2KDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const std::vector<char> &downgrade, int32_t scale, bool decode, const std::vector<char> &downgrade, int32_t scale, bool decode,
const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache) { const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler.get().Parse(); auto sampler_obj = sampler.get().Parse();
auto ds = std::make_shared<DIV2KNode>(CharToString(dataset_dir), CharToString(usage), CharToString(downgrade), scale, auto ds = std::make_shared<DIV2KNode>(CharToString(dataset_dir), CharToString(usage), CharToString(downgrade), scale,
decode, sampler_obj, cache); decode, sampler_obj, cache);
@ -1144,7 +1148,7 @@ EMnistDataset::EMnistDataset(const std::vector<char> &dataset_dir, const std::ve
} }
EMnistDataset::EMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &name, EMnistDataset::EMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &name,
const std::vector<char> &usage, const std::reference_wrapper<Sampler> sampler, const std::vector<char> &usage, const std::reference_wrapper<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache) { const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler.get().Parse(); auto sampler_obj = sampler.get().Parse();
auto ds = std::make_shared<EMnistNode>(CharToString(dataset_dir), CharToString(name), CharToString(usage), auto ds = std::make_shared<EMnistNode>(CharToString(dataset_dir), CharToString(name), CharToString(usage),
@ -1175,7 +1179,7 @@ FakeImageDataset::FakeImageDataset(int32_t num_images, const std::vector<int32_t
} }
FakeImageDataset::FakeImageDataset(int32_t num_images, const std::vector<int32_t> &image_size, int32_t num_classes, FakeImageDataset::FakeImageDataset(int32_t num_images, const std::vector<int32_t> &image_size, int32_t num_classes,
int32_t base_seed, const std::reference_wrapper<Sampler> sampler, int32_t base_seed, const std::reference_wrapper<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache) { const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler.get().Parse(); auto sampler_obj = sampler.get().Parse();
auto ds = std::make_shared<FakeImageNode>(num_images, image_size, num_classes, base_seed, sampler_obj, cache); auto ds = std::make_shared<FakeImageNode>(num_images, image_size, num_classes, base_seed, sampler_obj, cache);
@ -1198,7 +1202,7 @@ FashionMnistDataset::FashionMnistDataset(const std::vector<char> &dataset_dir, c
} }
FashionMnistDataset::FashionMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, FashionMnistDataset::FashionMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const std::reference_wrapper<Sampler> sampler, const std::reference_wrapper<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache) { const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler.get().Parse(); auto sampler_obj = sampler.get().Parse();
auto ds = std::make_shared<FashionMnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache); auto ds = std::make_shared<FashionMnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
@ -1223,7 +1227,7 @@ FlickrDataset::FlickrDataset(const std::vector<char> &dataset_dir, const std::ve
} }
FlickrDataset::FlickrDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file, FlickrDataset::FlickrDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
bool decode, const std::reference_wrapper<Sampler> sampler, bool decode, const std::reference_wrapper<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache) { const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler.get().Parse(); auto sampler_obj = sampler.get().Parse();
auto ds = auto ds =
@ -1261,7 +1265,7 @@ ImageFolderDataset::ImageFolderDataset(const std::vector<char> &dataset_dir, boo
} }
ImageFolderDataset::ImageFolderDataset(const std::vector<char> &dataset_dir, bool decode, ImageFolderDataset::ImageFolderDataset(const std::vector<char> &dataset_dir, bool decode,
const std::reference_wrapper<Sampler> sampler, const std::reference_wrapper<Sampler> &sampler,
const std::set<std::vector<char>> &extensions, const std::set<std::vector<char>> &extensions,
const std::map<std::vector<char>, int32_t> &class_indexing, const std::map<std::vector<char>, int32_t> &class_indexing,
const std::shared_ptr<DatasetCache> &cache) { const std::shared_ptr<DatasetCache> &cache) {
@ -1292,7 +1296,7 @@ IMDBDataset::IMDBDataset(const std::vector<char> &dataset_dir, const std::vector
} }
IMDBDataset::IMDBDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, IMDBDataset::IMDBDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache) { const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
// Create logical representation of IMDBDataset. // Create logical representation of IMDBDataset.
auto sampler_obj = sampler.get().Parse(); auto sampler_obj = sampler.get().Parse();
auto ds = std::make_shared<IMDBNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache); auto ds = std::make_shared<IMDBNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
@ -1335,7 +1339,7 @@ KMnistDataset::KMnistDataset(const std::vector<char> &dataset_dir, const std::ve
} }
KMnistDataset::KMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, KMnistDataset::KMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const std::reference_wrapper<Sampler> sampler, const std::reference_wrapper<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache) { const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler.get().Parse(); auto sampler_obj = sampler.get().Parse();
auto ds = std::make_shared<KMnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache); auto ds = std::make_shared<KMnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
@ -1356,7 +1360,7 @@ LJSpeechDataset::LJSpeechDataset(const std::vector<char> &dataset_dir, const Sam
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }
LJSpeechDataset::LJSpeechDataset(const std::vector<char> &dataset_dir, const std::reference_wrapper<Sampler> sampler, LJSpeechDataset::LJSpeechDataset(const std::vector<char> &dataset_dir, const std::reference_wrapper<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache) { const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler.get().Parse(); auto sampler_obj = sampler.get().Parse();
auto ds = std::make_shared<LJSpeechNode>(CharToString(dataset_dir), sampler_obj, cache); auto ds = std::make_shared<LJSpeechNode>(CharToString(dataset_dir), sampler_obj, cache);
@ -1372,6 +1376,7 @@ ManifestDataset::ManifestDataset(const std::vector<char> &dataset_file, const st
MapCharToString(class_indexing), decode, cache); MapCharToString(class_indexing), decode, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }
ManifestDataset::ManifestDataset(const std::vector<char> &dataset_file, const std::vector<char> &usage, ManifestDataset::ManifestDataset(const std::vector<char> &dataset_file, const std::vector<char> &usage,
const Sampler *sampler, const std::map<std::vector<char>, int32_t> &class_indexing, const Sampler *sampler, const std::map<std::vector<char>, int32_t> &class_indexing,
bool decode, const std::shared_ptr<DatasetCache> &cache) { bool decode, const std::shared_ptr<DatasetCache> &cache) {
@ -1380,8 +1385,9 @@ ManifestDataset::ManifestDataset(const std::vector<char> &dataset_file, const st
MapCharToString(class_indexing), decode, cache); MapCharToString(class_indexing), decode, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }
ManifestDataset::ManifestDataset(const std::vector<char> &dataset_file, const std::vector<char> &usage, ManifestDataset::ManifestDataset(const std::vector<char> &dataset_file, const std::vector<char> &usage,
const std::reference_wrapper<Sampler> sampler, const std::reference_wrapper<Sampler> &sampler,
const std::map<std::vector<char>, int32_t> &class_indexing, bool decode, const std::map<std::vector<char>, int32_t> &class_indexing, bool decode,
const std::shared_ptr<DatasetCache> &cache) { const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler.get().Parse(); auto sampler_obj = sampler.get().Parse();
@ -1421,7 +1427,7 @@ MindDataDataset::MindDataDataset(const std::vector<char> &dataset_file,
MindDataDataset::MindDataDataset(const std::vector<char> &dataset_file, MindDataDataset::MindDataDataset(const std::vector<char> &dataset_file,
const std::vector<std::vector<char>> &columns_list, const std::vector<std::vector<char>> &columns_list,
const std::reference_wrapper<Sampler> sampler, const nlohmann::json *padded_sample, const std::reference_wrapper<Sampler> &sampler, const nlohmann::json *padded_sample,
int64_t num_padded, ShuffleMode shuffle_mode, int64_t num_padded, ShuffleMode shuffle_mode,
const std::shared_ptr<DatasetCache> &cache) { const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler.get().Parse(); auto sampler_obj = sampler.get().Parse();
@ -1468,7 +1474,7 @@ MindDataDataset::MindDataDataset(const std::vector<std::vector<char>> &dataset_f
MindDataDataset::MindDataDataset(const std::vector<std::vector<char>> &dataset_files, MindDataDataset::MindDataDataset(const std::vector<std::vector<char>> &dataset_files,
const std::vector<std::vector<char>> &columns_list, const std::vector<std::vector<char>> &columns_list,
const std::reference_wrapper<Sampler> sampler, const nlohmann::json *padded_sample, const std::reference_wrapper<Sampler> &sampler, const nlohmann::json *padded_sample,
int64_t num_padded, ShuffleMode shuffle_mode, int64_t num_padded, ShuffleMode shuffle_mode,
const std::shared_ptr<DatasetCache> &cache) { const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler.get().Parse(); auto sampler_obj = sampler.get().Parse();
@ -1488,14 +1494,16 @@ MnistDataset::MnistDataset(const std::vector<char> &dataset_dir, const std::vect
auto ds = std::make_shared<MnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache); auto ds = std::make_shared<MnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }
MnistDataset::MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const Sampler *sampler, MnistDataset::MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const Sampler *sampler,
const std::shared_ptr<DatasetCache> &cache) { const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler ? sampler->Parse() : nullptr; auto sampler_obj = sampler ? sampler->Parse() : nullptr;
auto ds = std::make_shared<MnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache); auto ds = std::make_shared<MnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }
MnistDataset::MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, MnistDataset::MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache) { const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler.get().Parse(); auto sampler_obj = sampler.get().Parse();
auto ds = std::make_shared<MnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache); auto ds = std::make_shared<MnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
@ -1529,7 +1537,7 @@ PhotoTourDataset::PhotoTourDataset(const std::vector<char> &dataset_dir, const s
} }
PhotoTourDataset::PhotoTourDataset(const std::vector<char> &dataset_dir, const std::vector<char> &name, PhotoTourDataset::PhotoTourDataset(const std::vector<char> &dataset_dir, const std::vector<char> &name,
const std::vector<char> &usage, const std::reference_wrapper<Sampler> sampler, const std::vector<char> &usage, const std::reference_wrapper<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache) { const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler.get().Parse(); auto sampler_obj = sampler.get().Parse();
auto ds = std::make_shared<PhotoTourNode>(CharToString(dataset_dir), CharToString(name), CharToString(usage), auto ds = std::make_shared<PhotoTourNode>(CharToString(dataset_dir), CharToString(name), CharToString(usage),
@ -1556,7 +1564,7 @@ Places365Dataset::Places365Dataset(const std::vector<char> &dataset_dir, const s
} }
Places365Dataset::Places365Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, Places365Dataset::Places365Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const bool small, const bool decode, const std::reference_wrapper<Sampler> sampler, const bool small, const bool decode, const std::reference_wrapper<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache) { const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler.get().Parse(); auto sampler_obj = sampler.get().Parse();
auto ds = auto ds =
@ -1579,7 +1587,7 @@ QMnistDataset::QMnistDataset(const std::vector<char> &dataset_dir, const std::ve
} }
QMnistDataset::QMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool compat, QMnistDataset::QMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool compat,
const std::reference_wrapper<Sampler> sampler, const std::reference_wrapper<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache) { const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler.get().Parse(); auto sampler_obj = sampler.get().Parse();
auto ds = std::make_shared<QMnistNode>(CharToString(dataset_dir), CharToString(usage), compat, sampler_obj, cache); auto ds = std::make_shared<QMnistNode>(CharToString(dataset_dir), CharToString(usage), compat, sampler_obj, cache);
@ -1650,7 +1658,7 @@ STL10Dataset::STL10Dataset(const std::vector<char> &dataset_dir, const std::vect
} }
STL10Dataset::STL10Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, STL10Dataset::STL10Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache) { const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler.get().Parse(); auto sampler_obj = sampler.get().Parse();
auto ds = std::make_shared<STL10Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache); auto ds = std::make_shared<STL10Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
@ -1681,6 +1689,7 @@ VOCDataset::VOCDataset(const std::vector<char> &dataset_dir, const std::vector<c
MapCharToString(class_indexing), decode, sampler_obj, cache, extra_metadata); MapCharToString(class_indexing), decode, sampler_obj, cache, extra_metadata);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }
VOCDataset::VOCDataset(const std::vector<char> &dataset_dir, const std::vector<char> &task, VOCDataset::VOCDataset(const std::vector<char> &dataset_dir, const std::vector<char> &task,
const std::vector<char> &usage, const std::map<std::vector<char>, int32_t> &class_indexing, const std::vector<char> &usage, const std::map<std::vector<char>, int32_t> &class_indexing,
bool decode, const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache, bool decode, const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache,
@ -1690,9 +1699,10 @@ VOCDataset::VOCDataset(const std::vector<char> &dataset_dir, const std::vector<c
MapCharToString(class_indexing), decode, sampler_obj, cache, extra_metadata); MapCharToString(class_indexing), decode, sampler_obj, cache, extra_metadata);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }
VOCDataset::VOCDataset(const std::vector<char> &dataset_dir, const std::vector<char> &task, VOCDataset::VOCDataset(const std::vector<char> &dataset_dir, const std::vector<char> &task,
const std::vector<char> &usage, const std::map<std::vector<char>, int32_t> &class_indexing, const std::vector<char> &usage, const std::map<std::vector<char>, int32_t> &class_indexing,
bool decode, const std::reference_wrapper<Sampler> sampler, bool decode, const std::reference_wrapper<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache, bool extra_metadata) { const std::shared_ptr<DatasetCache> &cache, bool extra_metadata) {
auto sampler_obj = sampler.get().Parse(); auto sampler_obj = sampler.get().Parse();
auto ds = std::make_shared<VOCNode>(CharToString(dataset_dir), CharToString(task), CharToString(usage), auto ds = std::make_shared<VOCNode>(CharToString(dataset_dir), CharToString(task), CharToString(usage),
@ -1710,13 +1720,14 @@ WikiTextDataset::WikiTextDataset(const std::vector<char> &dataset_dir, const std
RandomDataDataset::RandomDataDataset(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema, RandomDataDataset::RandomDataDataset(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema,
const std::vector<std::vector<char>> &columns_list, const std::vector<std::vector<char>> &columns_list,
std::shared_ptr<DatasetCache> cache) { const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<RandomNode>(total_rows, std::move(schema), VectorCharToString(columns_list), cache); auto ds = std::make_shared<RandomNode>(total_rows, std::move(schema), VectorCharToString(columns_list), cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }
RandomDataDataset::RandomDataDataset(const int32_t &total_rows, const std::vector<char> &schema_path, RandomDataDataset::RandomDataDataset(const int32_t &total_rows, const std::vector<char> &schema_path,
const std::vector<std::vector<char>> &columns_list, const std::vector<std::vector<char>> &columns_list,
std::shared_ptr<DatasetCache> cache) { const std::shared_ptr<DatasetCache> &cache) {
auto ds = auto ds =
std::make_shared<RandomNode>(total_rows, CharToString(schema_path), VectorCharToString(columns_list), cache); std::make_shared<RandomNode>(total_rows, CharToString(schema_path), VectorCharToString(columns_list), cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
@ -1736,8 +1747,8 @@ SBUDataset::SBUDataset(const std::vector<char> &dataset_dir, bool decode, const
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }
SBUDataset::SBUDataset(const std::vector<char> &dataset_dir, bool decode, const std::reference_wrapper<Sampler> sampler, SBUDataset::SBUDataset(const std::vector<char> &dataset_dir, bool decode,
const std::shared_ptr<DatasetCache> &cache) { const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler.get().Parse(); auto sampler_obj = sampler.get().Parse();
auto ds = std::make_shared<SBUNode>(CharToString(dataset_dir), decode, sampler_obj, cache); auto ds = std::make_shared<SBUNode>(CharToString(dataset_dir), decode, sampler_obj, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
@ -1767,7 +1778,7 @@ SpeechCommandsDataset::SpeechCommandsDataset(const std::vector<char> &dataset_di
} }
SpeechCommandsDataset::SpeechCommandsDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, SpeechCommandsDataset::SpeechCommandsDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const std::reference_wrapper<Sampler> sampler, const std::reference_wrapper<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache) { const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler.get().Parse(); auto sampler_obj = sampler.get().Parse();
auto ds = std::make_shared<SpeechCommandsNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache); auto ds = std::make_shared<SpeechCommandsNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
@ -1777,16 +1788,18 @@ SpeechCommandsDataset::SpeechCommandsDataset(const std::vector<char> &dataset_di
TFRecordDataset::TFRecordDataset(const std::vector<std::vector<char>> &dataset_files, const std::vector<char> &schema, TFRecordDataset::TFRecordDataset(const std::vector<std::vector<char>> &dataset_files, const std::vector<char> &schema,
const std::vector<std::vector<char>> &columns_list, int64_t num_samples, const std::vector<std::vector<char>> &columns_list, int64_t num_samples,
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, bool shard_equal_rows, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, bool shard_equal_rows,
std::shared_ptr<DatasetCache> cache) { const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<TFRecordNode>(VectorCharToString(dataset_files), CharToString(schema), auto ds = std::make_shared<TFRecordNode>(VectorCharToString(dataset_files), CharToString(schema),
VectorCharToString(columns_list), num_samples, shuffle, num_shards, shard_id, VectorCharToString(columns_list), num_samples, shuffle, num_shards, shard_id,
shard_equal_rows, cache); shard_equal_rows, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }
TFRecordDataset::TFRecordDataset(const std::vector<std::vector<char>> &dataset_files, std::shared_ptr<SchemaObj> schema,
TFRecordDataset::TFRecordDataset(const std::vector<std::vector<char>> &dataset_files,
const std::shared_ptr<SchemaObj> &schema,
const std::vector<std::vector<char>> &columns_list, int64_t num_samples, const std::vector<std::vector<char>> &columns_list, int64_t num_samples,
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, bool shard_equal_rows, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, bool shard_equal_rows,
std::shared_ptr<DatasetCache> cache) { const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<TFRecordNode>(VectorCharToString(dataset_files), schema, VectorCharToString(columns_list), auto ds = std::make_shared<TFRecordNode>(VectorCharToString(dataset_files), schema, VectorCharToString(columns_list),
num_samples, shuffle, num_shards, shard_id, shard_equal_rows, cache); num_samples, shuffle, num_shards, shard_id, shard_equal_rows, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
@ -1816,7 +1829,7 @@ WIDERFaceDataset::WIDERFaceDataset(const std::vector<char> &dataset_dir, const s
} }
WIDERFaceDataset::WIDERFaceDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode, WIDERFaceDataset::WIDERFaceDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
const std::reference_wrapper<Sampler> sampler, const std::reference_wrapper<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache) { const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler.get().Parse(); auto sampler_obj = sampler.get().Parse();
auto ds = std::make_shared<WIDERFaceNode>(CharToString(dataset_dir), CharToString(usage), decode, sampler_obj, cache); auto ds = std::make_shared<WIDERFaceNode>(CharToString(dataset_dir), CharToString(usage), decode, sampler_obj, cache);
@ -1853,13 +1866,12 @@ YesNoDataset::YesNoDataset(const std::vector<char> &dataset_dir, const Sampler *
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }
YesNoDataset::YesNoDataset(const std::vector<char> &dataset_dir, const std::reference_wrapper<Sampler> sampler, YesNoDataset::YesNoDataset(const std::vector<char> &dataset_dir, const std::reference_wrapper<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache) { const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler.get().Parse(); auto sampler_obj = sampler.get().Parse();
auto ds = std::make_shared<YesNoNode>(CharToString(dataset_dir), sampler_obj, cache); auto ds = std::make_shared<YesNoNode>(CharToString(dataset_dir), sampler_obj, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }
#endif #endif
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020-2021 Huawei Technologies Co., Ltd * Copyright 2020-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -38,8 +38,8 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
using json = nlohmann::json; using json = nlohmann::json;
struct Execute::ExtraInfo { struct Execute::ExtraInfo {
std::multimap<std::string, std::vector<uint32_t>> aipp_cfg_; std::multimap<std::string, std::vector<uint32_t>> aipp_cfg_;
bool init_with_shared_ptr_ = true; // Initial execute object with shared_ptr as default bool init_with_shared_ptr_ = true; // Initial execute object with shared_ptr as default
@ -70,14 +70,14 @@ Status Execute::InitResource(MapTargetDevice device_type, uint32_t device_id) {
return Status::OK(); return Status::OK();
} }
Execute::Execute(std::shared_ptr<TensorOperation> op, MapTargetDevice device_type, uint32_t device_id) { Execute::Execute(const std::shared_ptr<TensorOperation> &op, MapTargetDevice device_type, uint32_t device_id) {
ops_.emplace_back(std::move(op)); ops_.emplace_back(op);
device_type_ = device_type; device_type_ = device_type;
info_ = std::make_shared<ExtraInfo>(); info_ = std::make_shared<ExtraInfo>();
(void)InitResource(device_type, device_id); (void)InitResource(device_type, device_id);
} }
Execute::Execute(std::shared_ptr<TensorTransform> op, MapTargetDevice device_type, uint32_t device_id) { Execute::Execute(const std::shared_ptr<TensorTransform> &op, MapTargetDevice device_type, uint32_t device_id) {
// Initialize the op and other context // Initialize the op and other context
transforms_.emplace_back(op); transforms_.emplace_back(op);
@ -86,7 +86,7 @@ Execute::Execute(std::shared_ptr<TensorTransform> op, MapTargetDevice device_typ
(void)InitResource(device_type, device_id); (void)InitResource(device_type, device_id);
} }
Execute::Execute(std::reference_wrapper<TensorTransform> op, MapTargetDevice device_type, uint32_t device_id) { Execute::Execute(const std::reference_wrapper<TensorTransform> &op, MapTargetDevice device_type, uint32_t device_id) {
// Initialize the transforms_ and other context // Initialize the transforms_ and other context
std::shared_ptr<TensorOperation> operation = op.get().Parse(); std::shared_ptr<TensorOperation> operation = op.get().Parse();
ops_.emplace_back(std::move(operation)); ops_.emplace_back(std::move(operation));
@ -108,13 +108,15 @@ Execute::Execute(TensorTransform *op, MapTargetDevice device_type, uint32_t devi
(void)InitResource(device_type, device_id); (void)InitResource(device_type, device_id);
} }
Execute::Execute(std::vector<std::shared_ptr<TensorOperation>> ops, MapTargetDevice device_type, uint32_t device_id) Execute::Execute(const std::vector<std::shared_ptr<TensorOperation>> &ops, MapTargetDevice device_type,
: ops_(std::move(ops)), device_type_(device_type) { uint32_t device_id)
: ops_(ops), device_type_(device_type) {
info_ = std::make_shared<ExtraInfo>(); info_ = std::make_shared<ExtraInfo>();
(void)InitResource(device_type, device_id); (void)InitResource(device_type, device_id);
} }
Execute::Execute(std::vector<std::shared_ptr<TensorTransform>> ops, MapTargetDevice device_type, uint32_t device_id) { Execute::Execute(const std::vector<std::shared_ptr<TensorTransform>> &ops, MapTargetDevice device_type,
uint32_t device_id) {
// Initialize the transforms_ and other context // Initialize the transforms_ and other context
transforms_ = ops; transforms_ = ops;
@ -123,7 +125,7 @@ Execute::Execute(std::vector<std::shared_ptr<TensorTransform>> ops, MapTargetDev
(void)InitResource(device_type, device_id); (void)InitResource(device_type, device_id);
} }
Execute::Execute(const std::vector<std::reference_wrapper<TensorTransform>> ops, MapTargetDevice device_type, Execute::Execute(const std::vector<std::reference_wrapper<TensorTransform>> &ops, MapTargetDevice device_type,
uint32_t device_id) { uint32_t device_id) {
// Initialize the transforms_ and other context // Initialize the transforms_ and other context
if (device_type == MapTargetDevice::kCpu) { if (device_type == MapTargetDevice::kCpu) {
@ -174,7 +176,6 @@ Status Execute::operator()(const mindspore::MSTensor &input, mindspore::MSTensor
// Validate input tensor // Validate input tensor
RETURN_UNEXPECTED_IF_NULL(output); RETURN_UNEXPECTED_IF_NULL(output);
CHECK_FAIL_RETURN_UNEXPECTED(input.DataSize() > 0, "Input Tensor has no data."); CHECK_FAIL_RETURN_UNEXPECTED(input.DataSize() > 0, "Input Tensor has no data.");
CHECK_FAIL_RETURN_UNEXPECTED(output != nullptr, "Output Tensor can not be nullptr.");
CHECK_FAIL_RETURN_UNEXPECTED(ValidateDevice(), "Device Type should be 'Ascend310' or 'CPU'."); CHECK_FAIL_RETURN_UNEXPECTED(ValidateDevice(), "Device Type should be 'Ascend310' or 'CPU'.");
// Parse TensorTransform transforms_ into TensorOperation ops_ // Parse TensorTransform transforms_ into TensorOperation ops_
@ -269,7 +270,6 @@ Status Execute::operator()(const std::vector<MSTensor> &input_tensor_list, std::
// Validate input tensor // Validate input tensor
RETURN_UNEXPECTED_IF_NULL(output_tensor_list); RETURN_UNEXPECTED_IF_NULL(output_tensor_list);
CHECK_FAIL_RETURN_UNEXPECTED(!input_tensor_list.empty(), "Input Tensor is not valid."); CHECK_FAIL_RETURN_UNEXPECTED(!input_tensor_list.empty(), "Input Tensor is not valid.");
CHECK_FAIL_RETURN_UNEXPECTED(output_tensor_list != nullptr, "Output Tensor can not be nullptr.");
output_tensor_list->clear(); output_tensor_list->clear();
for (auto &tensor : input_tensor_list) { for (auto &tensor : input_tensor_list) {
CHECK_FAIL_RETURN_UNEXPECTED(tensor.DataSize() > 0, "Input Tensor has no data."); CHECK_FAIL_RETURN_UNEXPECTED(tensor.DataSize() > 0, "Input Tensor has no data.");
@ -374,16 +374,16 @@ std::vector<uint32_t> AippSizeFilter(const std::vector<uint32_t> &resize_para, c
std::vector<uint32_t> aipp_size; std::vector<uint32_t> aipp_size;
// Special condition where (no Crop and no Resize) or (no Crop and resize with fixed ratio) will lead to dynamic input // Special condition where (no Crop and no Resize) or (no Crop and resize with fixed ratio) will lead to dynamic input
if ((resize_para.size() == 0 || resize_para.size() == 1) && crop_para.size() == 0) { if ((resize_para.empty() || resize_para.size() == 1) && crop_para.empty()) {
aipp_size = {0, 0}; aipp_size = {0, 0};
MS_LOG(WARNING) << "Dynamic input shape is not supported, incomplete aipp config file will be generated. Please " MS_LOG(WARNING) << "Dynamic input shape is not supported, incomplete aipp config file will be generated. Please "
"checkout your TensorTransform input, both src_image_size_h and src_image_size will be 0."; "checkout your TensorTransform input, both src_image_size_h and src_image_size will be 0.";
return aipp_size; return aipp_size;
} }
if (resize_para.size() == 0) { // If only Crop operator exists if (resize_para.empty()) { // If only Crop operator exists
aipp_size = crop_para; aipp_size = crop_para;
} else if (crop_para.size() == 0) { // If only Resize operator with 2 parameters exists } else if (crop_para.empty()) { // If only Resize operator with 2 parameters exists
aipp_size = resize_para; aipp_size = resize_para;
} else { // If both of them exist } else { // If both of them exist
if (resize_para.size() == 1) { if (resize_para.size() == 1) {
@ -437,6 +437,7 @@ std::vector<float> AippStdFilter(const std::vector<uint32_t> &normalize_para) {
Status AippInfoCollection(std::map<std::string, std::string> *aipp_options, const std::vector<uint32_t> &aipp_size, Status AippInfoCollection(std::map<std::string, std::string> *aipp_options, const std::vector<uint32_t> &aipp_size,
const std::vector<uint32_t> &aipp_mean, const std::vector<float> &aipp_std) { const std::vector<uint32_t> &aipp_mean, const std::vector<float> &aipp_std) {
RETURN_UNEXPECTED_IF_NULL(aipp_options);
// Several aipp config parameters // Several aipp config parameters
aipp_options->insert(std::make_pair("related_input_rank", "0")); aipp_options->insert(std::make_pair("related_input_rank", "0"));
aipp_options->insert(std::make_pair("src_image_size_w", std::to_string(aipp_size[1]))); aipp_options->insert(std::make_pair("src_image_size_w", std::to_string(aipp_size[1])));
@ -446,7 +447,7 @@ Status AippInfoCollection(std::map<std::string, std::string> *aipp_options, cons
aipp_options->insert(std::make_pair("aipp_mode", "static")); aipp_options->insert(std::make_pair("aipp_mode", "static"));
aipp_options->insert(std::make_pair("csc_switch", "true")); aipp_options->insert(std::make_pair("csc_switch", "true"));
aipp_options->insert(std::make_pair("rbuv_swap_switch", "false")); aipp_options->insert(std::make_pair("rbuv_swap_switch", "false"));
// Y = AX + b, this part is A // Y = AX + b, this part is A
std::vector<int32_t> color_space_matrix = {256, 0, 359, 256, -88, -183, 256, 454, 0}; std::vector<int32_t> color_space_matrix = {256, 0, 359, 256, -88, -183, 256, 454, 0};
int count = 0; int count = 0;
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
@ -489,10 +490,7 @@ std::string Execute::AippCfgGenerator() {
#ifdef ENABLE_ACL #ifdef ENABLE_ACL
if (info_->init_with_shared_ptr_) { if (info_->init_with_shared_ptr_) {
auto rc = ParseTransforms(); auto rc = ParseTransforms();
if (rc.IsError()) { RETURN_SECOND_IF_ERROR(rc, "");
MS_LOG(ERROR) << "Parse transforms failed, error msg is " << rc;
return "";
}
info_->init_with_shared_ptr_ = false; info_->init_with_shared_ptr_ = false;
} }
std::vector<uint32_t> paras; // Record the parameters value of each Ascend operators std::vector<uint32_t> paras; // Record the parameters value of each Ascend operators
@ -574,6 +572,7 @@ std::string Execute::AippCfgGenerator() {
auto rc = AippInfoCollection(&aipp_options, aipp_size, aipp_mean, aipp_std); auto rc = AippInfoCollection(&aipp_options, aipp_size, aipp_mean, aipp_std);
if (rc.IsError()) { if (rc.IsError()) {
MS_LOG(ERROR) << "Aipp information initialization failed, error msg is " << rc; MS_LOG(ERROR) << "Aipp information initialization failed, error msg is " << rc;
outfile.close();
return ""; return "";
} }
@ -594,7 +593,7 @@ std::string Execute::AippCfgGenerator() {
return config_location; return config_location;
} }
bool IsEmptyPtr(std::shared_ptr<TensorTransform> api_ptr) { return api_ptr == nullptr; } bool IsEmptyPtr(const std::shared_ptr<TensorTransform> &api_ptr) { return api_ptr == nullptr; }
Status Execute::ParseTransforms() { Status Execute::ParseTransforms() {
auto iter = std::find_if(transforms_.begin(), transforms_.end(), IsEmptyPtr); auto iter = std::find_if(transforms_.begin(), transforms_.end(), IsEmptyPtr);
@ -606,7 +605,7 @@ Status Execute::ParseTransforms() {
if (device_type_ == MapTargetDevice::kCpu) { if (device_type_ == MapTargetDevice::kCpu) {
(void)std::transform(transforms_.begin(), transforms_.end(), std::back_inserter(ops_), (void)std::transform(transforms_.begin(), transforms_.end(), std::back_inserter(ops_),
[](std::shared_ptr<TensorTransform> operation) -> std::shared_ptr<TensorOperation> { [](const std::shared_ptr<TensorTransform> &operation) -> std::shared_ptr<TensorOperation> {
return operation->Parse(); return operation->Parse();
}); });
} else if (device_type_ == MapTargetDevice::kAscend310) { } else if (device_type_ == MapTargetDevice::kAscend310) {
@ -645,10 +644,11 @@ Status Execute::DeviceMemoryRelease() {
Status Execute::Run(const std::vector<std::shared_ptr<dataset::Execute>> &data_graph, Status Execute::Run(const std::vector<std::shared_ptr<dataset::Execute>> &data_graph,
const std::vector<mindspore::MSTensor> &inputs, std::vector<mindspore::MSTensor> *outputs) { const std::vector<mindspore::MSTensor> &inputs, std::vector<mindspore::MSTensor> *outputs) {
RETURN_UNEXPECTED_IF_NULL(outputs);
std::vector<MSTensor> transform_inputs = inputs; std::vector<MSTensor> transform_inputs = inputs;
std::vector<MSTensor> transform_outputs; std::vector<MSTensor> transform_outputs;
if (!data_graph.empty()) { if (!data_graph.empty()) {
for (auto exes : data_graph) { for (const auto &exes : data_graph) {
CHECK_FAIL_RETURN_UNEXPECTED(exes != nullptr, "Given execute object is null."); CHECK_FAIL_RETURN_UNEXPECTED(exes != nullptr, "Given execute object is null.");
Status ret = exes->operator()(transform_inputs, &transform_outputs); Status ret = exes->operator()(transform_inputs, &transform_outputs);
if (ret != kSuccess) { if (ret != kSuccess) {
@ -675,9 +675,11 @@ extern "C" {
void ExecuteRun_C(const std::vector<std::shared_ptr<dataset::Execute>> &data_graph, void ExecuteRun_C(const std::vector<std::shared_ptr<dataset::Execute>> &data_graph,
std::vector<mindspore::MSTensor> &inputs, std::vector<mindspore::MSTensor> *outputs, Status *s) { std::vector<mindspore::MSTensor> &inputs, std::vector<mindspore::MSTensor> *outputs, Status *s) {
Status ret = Execute::Run(data_graph, inputs, outputs); Status ret = Execute::Run(data_graph, inputs, outputs);
if (s == nullptr) {
return;
}
*s = Status(ret); *s = Status(ret);
} }
} }
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020-2021 Huawei Technologies Co., Ltd * Copyright 2020-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -14,6 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
#include "minddata/dataset/include/dataset/iterator.h" #include "minddata/dataset/include/dataset/iterator.h"
#include "minddata/dataset/engine/consumers/pull_based_tree_consumer.h" #include "minddata/dataset/engine/consumers/pull_based_tree_consumer.h"
#include "minddata/dataset/engine/consumers/tree_consumer.h" #include "minddata/dataset/engine/consumers/tree_consumer.h"
#include "minddata/dataset/engine/runtime_context.h" #include "minddata/dataset/engine/runtime_context.h"
@ -75,7 +76,7 @@ void Iterator::Stop() {
} }
// Function to build and launch the execution tree. // Function to build and launch the execution tree.
Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds, int32_t num_epochs) { Status Iterator::BuildAndLaunchTree(const std::shared_ptr<Dataset> &ds, int32_t num_epochs) {
RETURN_UNEXPECTED_IF_NULL(ds); RETURN_UNEXPECTED_IF_NULL(ds);
runtime_context_ = std::make_unique<NativeRuntimeContext>(); runtime_context_ = std::make_unique<NativeRuntimeContext>();
CHECK_FAIL_RETURN_UNEXPECTED(runtime_context_ != nullptr, "Create runtime_context_ failed."); CHECK_FAIL_RETURN_UNEXPECTED(runtime_context_ != nullptr, "Create runtime_context_ failed.");
@ -105,7 +106,7 @@ Status PullIterator::GetRows(int32_t num_rows, std::vector<MSTensorVec> *const r
} }
MSTensorVec ms_row = {}; MSTensorVec ms_row = {};
for (auto de_tensor : md_row) { for (const auto &de_tensor : md_row) {
CHECK_FAIL_RETURN_UNEXPECTED(de_tensor->HasData(), "Apply transform failed, output tensor has no data"); CHECK_FAIL_RETURN_UNEXPECTED(de_tensor->HasData(), "Apply transform failed, output tensor has no data");
ms_row.push_back(mindspore::MSTensor(std::make_shared<DETensor>(de_tensor))); ms_row.push_back(mindspore::MSTensor(std::make_shared<DETensor>(de_tensor)));
} }
@ -125,7 +126,7 @@ Status PullIterator::GetNextRow(MSTensorVec *const row) {
return rc; return rc;
} }
for (auto de_tensor : md_row) { for (const auto &de_tensor : md_row) {
CHECK_FAIL_RETURN_UNEXPECTED(de_tensor->HasData(), "Apply transform failed, output tensor has no data"); CHECK_FAIL_RETURN_UNEXPECTED(de_tensor->HasData(), "Apply transform failed, output tensor has no data");
row->push_back(mindspore::MSTensor(std::make_shared<DETensor>(de_tensor))); row->push_back(mindspore::MSTensor(std::make_shared<DETensor>(de_tensor)));
} }
@ -135,7 +136,7 @@ Status PullIterator::GetNextRow(MSTensorVec *const row) {
// Function to build and launch the execution tree. This function kicks off a different type of consumer // Function to build and launch the execution tree. This function kicks off a different type of consumer
// for the tree, the reason why this is the case is due to the fact that PullBasedIterator does not need // for the tree, the reason why this is the case is due to the fact that PullBasedIterator does not need
// to instantiate threads for each op. As such, the call to the consumer will by pass the execution tree. // to instantiate threads for each op. As such, the call to the consumer will by pass the execution tree.
Status PullIterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) { Status PullIterator::BuildAndLaunchTree(const std::shared_ptr<Dataset> &ds) {
if (pull_consumer_ == nullptr) { if (pull_consumer_ == nullptr) {
pull_consumer_ = std::make_unique<PullBasedIteratorConsumer>(); pull_consumer_ = std::make_unique<PullBasedIteratorConsumer>();
} }
@ -167,7 +168,7 @@ Iterator::_Iterator &Iterator::_Iterator::operator++() {
cur_row_ = nullptr; cur_row_ = nullptr;
} }
} }
if (cur_row_ && cur_row_->size() == 0) { if (cur_row_ && cur_row_->empty()) {
delete cur_row_; delete cur_row_;
cur_row_ = nullptr; cur_row_ = nullptr;
} }

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -60,7 +60,6 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
PYBIND_REGISTER( PYBIND_REGISTER(
AllpassBiquadOperation, 1, ([](const py::module *m) { AllpassBiquadOperation, 1, ([](const py::module *m) {
(void)py::class_<audio::AllpassBiquadOperation, TensorOperation, std::shared_ptr<audio::AllpassBiquadOperation>>( (void)py::class_<audio::AllpassBiquadOperation, TensorOperation, std::shared_ptr<audio::AllpassBiquadOperation>>(

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020-2021 Huawei Technologies Co., Ltd * Copyright 2020-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -13,15 +13,17 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h" #include "pybind11/pybind11.h"
#include "pybind11/stl_bind.h" #include "pybind11/stl_bind.h"
#include "minddata/dataset/api/python/pybind_register.h" #include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/core/client.h" // DE client #include "minddata/dataset/core/client.h" // DE client
#include "minddata/dataset/util/status.h" #include "minddata/dataset/core/global_context.h"
#include "pybind11/numpy.h"
#include "minddata/dataset/include/dataset/constants.h" #include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/util/status.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
@ -63,12 +65,12 @@ PYBIND_REGISTER(ConfigManager, 0, ([](const py::module *m) {
.def("get_enable_autotune", &ConfigManager::enable_autotune) .def("get_enable_autotune", &ConfigManager::enable_autotune)
.def("set_autotune_interval", &ConfigManager::set_autotune_interval) .def("set_autotune_interval", &ConfigManager::set_autotune_interval)
.def("get_autotune_interval", &ConfigManager::autotune_interval) .def("get_autotune_interval", &ConfigManager::autotune_interval)
.def("load", [](ConfigManager &c, std::string s) { THROW_IF_ERROR(c.LoadFile(s)); }); .def("load", [](ConfigManager &c, const std::string &s) { THROW_IF_ERROR(c.LoadFile(s)); });
})); }));
PYBIND_REGISTER(Tensor, 0, ([](const py::module *m) { PYBIND_REGISTER(Tensor, 0, ([](const py::module *m) {
(void)py::class_<Tensor, std::shared_ptr<Tensor>>(*m, "Tensor", py::buffer_protocol()) (void)py::class_<Tensor, std::shared_ptr<Tensor>>(*m, "Tensor", py::buffer_protocol())
.def(py::init([](py::array arr) { .def(py::init([](const py::array &arr) {
std::shared_ptr<Tensor> out; std::shared_ptr<Tensor> out;
THROW_IF_ERROR(Tensor::CreateFromNpArray(arr, &out)); THROW_IF_ERROR(Tensor::CreateFromNpArray(arr, &out));
return out; return out;
@ -107,7 +109,7 @@ PYBIND_REGISTER(DataType, 0, ([](const py::module *m) {
.def(py::init<std::string>()) .def(py::init<std::string>())
.def(py::self == py::self) .def(py::self == py::self)
.def("__str__", &DataType::ToString) .def("__str__", &DataType::ToString)
.def("__deepcopy__", [](py::object &t, py::dict memo) { return t; }); .def("__deepcopy__", [](py::object &t, const py::dict &memo) { return t; });
})); }));
PYBIND_REGISTER(AutoAugmentPolicy, 0, ([](const py::module *m) { PYBIND_REGISTER(AutoAugmentPolicy, 0, ([](const py::module *m) {

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -23,7 +23,6 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
PYBIND_REGISTER( PYBIND_REGISTER(
Graph, 0, ([](const py::module *m) { Graph, 0, ([](const py::module *m) {
(void)py::class_<gnn::GraphData, std::shared_ptr<gnn::GraphData>>(*m, "GraphDataClient") (void)py::class_<gnn::GraphData, std::shared_ptr<gnn::GraphData>>(*m, "GraphDataClient")
@ -51,7 +50,7 @@ PYBIND_REGISTER(
return out; return out;
}) })
.def("get_nodes_from_edges", .def("get_nodes_from_edges",
[](gnn::GraphData &g, std::vector<gnn::NodeIdType> edge_list) { [](gnn::GraphData &g, const std::vector<gnn::NodeIdType> &edge_list) {
std::shared_ptr<Tensor> out; std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out)); THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out));
return out; return out;
@ -70,31 +69,34 @@ PYBIND_REGISTER(
return out; return out;
}) })
.def("get_sampled_neighbors", .def("get_sampled_neighbors",
[](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeIdType> neighbor_nums, [](gnn::GraphData &g, const std::vector<gnn::NodeIdType> &node_list,
std::vector<gnn::NodeType> neighbor_types, SamplingStrategy strategy) { const std::vector<gnn::NodeIdType> &neighbor_nums, const std::vector<gnn::NodeType> &neighbor_types,
SamplingStrategy strategy) {
std::shared_ptr<Tensor> out; std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, strategy, &out)); THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, strategy, &out));
return out; return out;
}) })
.def("get_neg_sampled_neighbors", .def("get_neg_sampled_neighbors",
[](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeIdType neighbor_num, [](gnn::GraphData &g, const std::vector<gnn::NodeIdType> &node_list, gnn::NodeIdType neighbor_num,
gnn::NodeType neg_neighbor_type) { gnn::NodeType neg_neighbor_type) {
std::shared_ptr<Tensor> out; std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetNegSampledNeighbors(node_list, neighbor_num, neg_neighbor_type, &out)); THROW_IF_ERROR(g.GetNegSampledNeighbors(node_list, neighbor_num, neg_neighbor_type, &out));
return out; return out;
}) })
.def("get_node_feature", .def(
[](gnn::GraphData &g, std::shared_ptr<Tensor> node_list, std::vector<gnn::FeatureType> feature_types) { "get_node_feature",
TensorRow out; [](gnn::GraphData &g, const std::shared_ptr<Tensor> &node_list, std::vector<gnn::FeatureType> feature_types) {
THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out)); TensorRow out;
return out.getRow(); THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out));
}) return out.getRow();
.def("get_edge_feature", })
[](gnn::GraphData &g, std::shared_ptr<Tensor> edge_list, std::vector<gnn::FeatureType> feature_types) { .def(
TensorRow out; "get_edge_feature",
THROW_IF_ERROR(g.GetEdgeFeature(edge_list, feature_types, &out)); [](gnn::GraphData &g, const std::shared_ptr<Tensor> &edge_list, std::vector<gnn::FeatureType> feature_types) {
return out.getRow(); TensorRow out;
}) THROW_IF_ERROR(g.GetEdgeFeature(edge_list, feature_types, &out));
return out.getRow();
})
.def("graph_info", .def("graph_info",
[](gnn::GraphData &g) { [](gnn::GraphData &g) {
py::dict out; py::dict out;
@ -102,8 +104,9 @@ PYBIND_REGISTER(
return out; return out;
}) })
.def("random_walk", .def("random_walk",
[](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeType> meta_path, [](gnn::GraphData &g, const std::vector<gnn::NodeIdType> &node_list,
float step_home_param, float step_away_param, gnn::NodeIdType default_node) { const std::vector<gnn::NodeType> &meta_path, float step_home_param, float step_away_param,
gnn::NodeIdType default_node) {
std::shared_ptr<Tensor> out; std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.RandomWalk(node_list, meta_path, step_home_param, step_away_param, default_node, &out)); THROW_IF_ERROR(g.RandomWalk(node_list, meta_path, step_home_param, step_away_param, default_node, &out));
return out; return out;
@ -137,6 +140,5 @@ PYBIND_REGISTER(OutputFormat, 0, ([](const py::module *m) {
.value("DE_FORMAT_CSR", OutputFormat::kCsr) .value("DE_FORMAT_CSR", OutputFormat::kCsr)
.export_values(); .export_values();
})); }));
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020-2021 Huawei Technologies Co., Ltd * Copyright 2020-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -18,9 +18,12 @@
#include "minddata/dataset/api/python/pybind_conversion.h" #include "minddata/dataset/api/python/pybind_conversion.h"
#include "minddata/dataset/api/python/pybind_register.h" #include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/callback/py_ds_callback.h" #include "minddata/dataset/callback/py_ds_callback.h"
#include "minddata/dataset/include/dataset/constants.h" #include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/data_type.h"
#include "minddata/dataset/engine/serdes.h" #include "minddata/dataset/engine/serdes.h"
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/text/sentence_piece_vocab.h" #include "minddata/dataset/text/sentence_piece_vocab.h"
#include "minddata/dataset/util/path.h"
// IR non-leaf nodes // IR non-leaf nodes
#include "minddata/dataset/engine/ir/datasetops/batch_node.h" #include "minddata/dataset/engine/ir/datasetops/batch_node.h"
@ -40,37 +43,33 @@
// IR non-leaf nodes - for android // IR non-leaf nodes - for android
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
#include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
#include "minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h" #include "minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h"
#include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
#include "minddata/dataset/engine/ir/datasetops/sync_wait_node.h" #include "minddata/dataset/engine/ir/datasetops/sync_wait_node.h"
#endif #endif
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/data_type.h"
#include "minddata/dataset/util/path.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
PYBIND_REGISTER(DatasetNode, 1, ([](const py::module *m) { PYBIND_REGISTER(DatasetNode, 1, ([](const py::module *m) {
(void)py::class_<DatasetNode, std::shared_ptr<DatasetNode>>(*m, "Dataset") (void)py::class_<DatasetNode, std::shared_ptr<DatasetNode>>(*m, "Dataset")
.def("set_num_workers", .def("set_num_workers",
[](std::shared_ptr<DatasetNode> self, std::optional<int32_t> num_workers) { [](const std::shared_ptr<DatasetNode> &self, std::optional<int32_t> num_workers) {
return num_workers ? self->SetNumWorkers(*num_workers) : self; return num_workers ? self->SetNumWorkers(*num_workers) : self;
}) })
.def("set_cache_client", .def("set_cache_client",
[](std::shared_ptr<DatasetNode> self, std::shared_ptr<CacheClient> cc) { [](const std::shared_ptr<DatasetNode> &self, std::shared_ptr<CacheClient> cc) {
return self->SetDatasetCache(toDatasetCache(std::move(cc))); return self->SetDatasetCache(toDatasetCache(std::move(cc)));
}) })
.def( .def(
"Zip", "Zip",
[](std::shared_ptr<DatasetNode> self, py::list datasets) { [](const std::shared_ptr<DatasetNode> &self, const py::list &datasets) {
auto zip = std::make_shared<ZipNode>(std::move(toDatasetNode(self, datasets))); auto zip = std::make_shared<ZipNode>(std::move(toDatasetNode(self, datasets)));
THROW_IF_ERROR(zip->ValidateParams()); THROW_IF_ERROR(zip->ValidateParams());
return zip; return zip;
}, },
py::arg("datasets")) py::arg("datasets"))
.def("to_json", .def("to_json",
[](std::shared_ptr<DatasetNode> self, const std::string &json_filepath) { [](const std::shared_ptr<DatasetNode> &self, const std::string &json_filepath) {
nlohmann::json args; nlohmann::json args;
THROW_IF_ERROR(Serdes::SaveToJSON(self, json_filepath, &args)); THROW_IF_ERROR(Serdes::SaveToJSON(self, json_filepath, &args));
return args.dump(); return args.dump();
@ -92,35 +91,34 @@ PYBIND_REGISTER(DatasetNode, 1, ([](const py::module *m) {
// PYBIND FOR NON-LEAF NODES // PYBIND FOR NON-LEAF NODES
// (In alphabetical order) // (In alphabetical order)
PYBIND_REGISTER(BatchNode, 2, ([](const py::module *m) { PYBIND_REGISTER(
(void)py::class_<BatchNode, DatasetNode, std::shared_ptr<BatchNode>>(*m, "BatchNode", BatchNode, 2, ([](const py::module *m) {
"to create a BatchNode") (void)py::class_<BatchNode, DatasetNode, std::shared_ptr<BatchNode>>(*m, "BatchNode", "to create a BatchNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, int32_t batch_size, bool drop_remainder, .def(py::init([](const std::shared_ptr<DatasetNode> &self, int32_t batch_size, bool drop_remainder, bool pad,
bool pad, py::list in_col_names, py::list out_col_names, py::list col_order, const py::list &in_col_names, const py::list &out_col_names, const py::list &col_order,
py::object size_obj, py::object map_obj, py::dict pad_info) { const py::object &size_obj, const py::object &map_obj, const py::dict &pad_info) {
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> c_pad_info; std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> c_pad_info;
if (pad) { if (pad) {
THROW_IF_ERROR(toPadInfo(pad_info, &c_pad_info)); THROW_IF_ERROR(toPadInfo(pad_info, &c_pad_info));
} }
py::function size_func = py::function size_func =
py::isinstance<py::function>(size_obj) ? size_obj.cast<py::function>() : py::function(); py::isinstance<py::function>(size_obj) ? size_obj.cast<py::function>() : py::function();
py::function map_func = py::function map_func = py::isinstance<py::function>(map_obj) ? map_obj.cast<py::function>() : py::function();
py::isinstance<py::function>(map_obj) ? map_obj.cast<py::function>() : py::function(); auto batch = std::make_shared<BatchNode>(self, batch_size, drop_remainder, pad, toStringVector(in_col_names),
auto batch = std::make_shared<BatchNode>( toStringVector(out_col_names), toStringVector(col_order), size_func,
self, batch_size, drop_remainder, pad, toStringVector(in_col_names), map_func, c_pad_info);
toStringVector(out_col_names), toStringVector(col_order), size_func, map_func, c_pad_info); THROW_IF_ERROR(batch->ValidateParams());
THROW_IF_ERROR(batch->ValidateParams()); return batch;
return batch; }));
})); }));
}));
PYBIND_REGISTER(BucketBatchByLengthNode, 2, ([](const py::module *m) { PYBIND_REGISTER(BucketBatchByLengthNode, 2, ([](const py::module *m) {
(void)py::class_<BucketBatchByLengthNode, DatasetNode, std::shared_ptr<BucketBatchByLengthNode>>( (void)py::class_<BucketBatchByLengthNode, DatasetNode, std::shared_ptr<BucketBatchByLengthNode>>(
*m, "BucketBatchByLengthNode", "to create a BucketBatchByLengthNode") *m, "BucketBatchByLengthNode", "to create a BucketBatchByLengthNode")
.def(py::init([](std::shared_ptr<DatasetNode> dataset, py::list column_names, .def(py::init([](const std::shared_ptr<DatasetNode> &dataset, const py::list &column_names,
std::vector<int32_t> bucket_boundaries, std::vector<int32_t> bucket_batch_sizes, const std::vector<int32_t> &bucket_boundaries,
py::object element_length_function, py::dict pad_info, bool pad_to_bucket_boundary, const std::vector<int32_t> &bucket_batch_sizes, py::object element_length_function,
bool drop_remainder) { const py::dict &pad_info, bool pad_to_bucket_boundary, bool drop_remainder) {
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> c_pad_info; std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> c_pad_info;
THROW_IF_ERROR(toPadInfo(pad_info, &c_pad_info)); THROW_IF_ERROR(toPadInfo(pad_info, &c_pad_info));
@ -139,22 +137,23 @@ PYBIND_REGISTER(BucketBatchByLengthNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(BuildSentenceVocabNode, 2, ([](const py::module *m) { PYBIND_REGISTER(BuildSentenceVocabNode, 2, ([](const py::module *m) {
(void)py::class_<BuildSentenceVocabNode, DatasetNode, std::shared_ptr<BuildSentenceVocabNode>>( (void)py::class_<BuildSentenceVocabNode, DatasetNode, std::shared_ptr<BuildSentenceVocabNode>>(
*m, "BuildSentenceVocabNode", "to create a BuildSentenceVocabNode") *m, "BuildSentenceVocabNode", "to create a BuildSentenceVocabNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, std::shared_ptr<SentencePieceVocab> vocab, .def(py::init(
const std::vector<std::string> &col_names, int32_t vocab_size, [](const std::shared_ptr<DatasetNode> &self, const std::shared_ptr<SentencePieceVocab> &vocab,
float character_coverage, SentencePieceModel model_type, const std::vector<std::string> &col_names, int32_t vocab_size, float character_coverage,
const std::unordered_map<std::string, std::string> &params) { SentencePieceModel model_type, const std::unordered_map<std::string, std::string> &params) {
auto build_sentence_vocab = std::make_shared<BuildSentenceVocabNode>( auto build_sentence_vocab = std::make_shared<BuildSentenceVocabNode>(
self, vocab, col_names, vocab_size, character_coverage, model_type, params); self, vocab, col_names, vocab_size, character_coverage, model_type, params);
THROW_IF_ERROR(build_sentence_vocab->ValidateParams()); THROW_IF_ERROR(build_sentence_vocab->ValidateParams());
return build_sentence_vocab; return build_sentence_vocab;
})); }));
})); }));
PYBIND_REGISTER(BuildVocabNode, 2, ([](const py::module *m) { PYBIND_REGISTER(BuildVocabNode, 2, ([](const py::module *m) {
(void)py::class_<BuildVocabNode, DatasetNode, std::shared_ptr<BuildVocabNode>>( (void)py::class_<BuildVocabNode, DatasetNode, std::shared_ptr<BuildVocabNode>>(
*m, "BuildVocabNode", "to create a BuildVocabNode") *m, "BuildVocabNode", "to create a BuildVocabNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, std::shared_ptr<Vocab> vocab, py::list columns, .def(py::init([](const std::shared_ptr<DatasetNode> &self, const std::shared_ptr<Vocab> &vocab,
py::tuple freq_range, int64_t top_k, py::list special_tokens, bool special_first) { const py::list &columns, const py::tuple &freq_range, int64_t top_k,
py::list special_tokens, bool special_first) {
auto build_vocab = auto build_vocab =
std::make_shared<BuildVocabNode>(self, vocab, toStringVector(columns), toIntPair(freq_range), std::make_shared<BuildVocabNode>(self, vocab, toStringVector(columns), toIntPair(freq_range),
top_k, toStringVector(special_tokens), special_first); top_k, toStringVector(special_tokens), special_first);
@ -166,8 +165,8 @@ PYBIND_REGISTER(BuildVocabNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(ConcatNode, 2, ([](const py::module *m) { PYBIND_REGISTER(ConcatNode, 2, ([](const py::module *m) {
(void)py::class_<ConcatNode, DatasetNode, std::shared_ptr<ConcatNode>>(*m, "ConcatNode", (void)py::class_<ConcatNode, DatasetNode, std::shared_ptr<ConcatNode>>(*m, "ConcatNode",
"to create a ConcatNode") "to create a ConcatNode")
.def(py::init([](std::vector<std::shared_ptr<DatasetNode>> datasets, py::handle sampler, .def(py::init([](const std::vector<std::shared_ptr<DatasetNode>> &datasets, py::handle sampler,
py::list children_flag_and_nums, py::list children_start_end_index) { const py::list &children_flag_and_nums, const py::list &children_start_end_index) {
auto concat = std::make_shared<ConcatNode>(datasets, toSamplerObj(sampler), auto concat = std::make_shared<ConcatNode>(datasets, toSamplerObj(sampler),
toPairVector(children_flag_and_nums), toPairVector(children_flag_and_nums),
toPairVector(children_start_end_index)); toPairVector(children_start_end_index));
@ -179,8 +178,8 @@ PYBIND_REGISTER(ConcatNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(FilterNode, 2, ([](const py::module *m) { PYBIND_REGISTER(FilterNode, 2, ([](const py::module *m) {
(void)py::class_<FilterNode, DatasetNode, std::shared_ptr<FilterNode>>(*m, "FilterNode", (void)py::class_<FilterNode, DatasetNode, std::shared_ptr<FilterNode>>(*m, "FilterNode",
"to create a FilterNode") "to create a FilterNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, py::object predicate, .def(py::init([](const std::shared_ptr<DatasetNode> &self, const py::object &predicate,
std::vector<std::string> input_columns) { const std::vector<std::string> &input_columns) {
auto filter = auto filter =
std::make_shared<FilterNode>(self, toPyFuncOp(predicate, DataType::DE_BOOL), input_columns); std::make_shared<FilterNode>(self, toPyFuncOp(predicate, DataType::DE_BOOL), input_columns);
THROW_IF_ERROR(filter->ValidateParams()); THROW_IF_ERROR(filter->ValidateParams());
@ -190,8 +189,9 @@ PYBIND_REGISTER(FilterNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(MapNode, 2, ([](const py::module *m) { PYBIND_REGISTER(MapNode, 2, ([](const py::module *m) {
(void)py::class_<MapNode, DatasetNode, std::shared_ptr<MapNode>>(*m, "MapNode", "to create a MapNode") (void)py::class_<MapNode, DatasetNode, std::shared_ptr<MapNode>>(*m, "MapNode", "to create a MapNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, py::list operations, py::list input_columns, .def(py::init([](const std::shared_ptr<DatasetNode> &self, const py::list &operations,
py::list output_columns, py::list project_columns, const py::list &input_columns, const py::list &output_columns,
const py::list &project_columns,
std::vector<std::shared_ptr<PyDSCallback>> py_callbacks, int64_t max_rowsize, std::vector<std::shared_ptr<PyDSCallback>> py_callbacks, int64_t max_rowsize,
const ManualOffloadMode offload) { const ManualOffloadMode offload) {
auto map = std::make_shared<MapNode>( auto map = std::make_shared<MapNode>(
@ -206,27 +206,29 @@ PYBIND_REGISTER(MapNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(ProjectNode, 2, ([](const py::module *m) { PYBIND_REGISTER(ProjectNode, 2, ([](const py::module *m) {
(void)py::class_<ProjectNode, DatasetNode, std::shared_ptr<ProjectNode>>(*m, "ProjectNode", (void)py::class_<ProjectNode, DatasetNode, std::shared_ptr<ProjectNode>>(*m, "ProjectNode",
"to create a ProjectNode") "to create a ProjectNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, py::list columns) { .def(py::init([](const std::shared_ptr<DatasetNode> &self, const py::list &columns) {
auto project = std::make_shared<ProjectNode>(self, toStringVector(columns)); auto project = std::make_shared<ProjectNode>(self, toStringVector(columns));
THROW_IF_ERROR(project->ValidateParams()); THROW_IF_ERROR(project->ValidateParams());
return project; return project;
})); }));
})); }));
PYBIND_REGISTER( PYBIND_REGISTER(RenameNode, 2, ([](const py::module *m) {
RenameNode, 2, ([](const py::module *m) { (void)py::class_<RenameNode, DatasetNode, std::shared_ptr<RenameNode>>(*m, "RenameNode",
(void)py::class_<RenameNode, DatasetNode, std::shared_ptr<RenameNode>>(*m, "RenameNode", "to create a RenameNode") "to create a RenameNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, py::list input_columns, py::list output_columns) { .def(py::init([](const std::shared_ptr<DatasetNode> &self, const py::list &input_columns,
auto rename = std::make_shared<RenameNode>(self, toStringVector(input_columns), toStringVector(output_columns)); const py::list &output_columns) {
THROW_IF_ERROR(rename->ValidateParams()); auto rename = std::make_shared<RenameNode>(self, toStringVector(input_columns),
return rename; toStringVector(output_columns));
})); THROW_IF_ERROR(rename->ValidateParams());
})); return rename;
}));
}));
PYBIND_REGISTER(RepeatNode, 2, ([](const py::module *m) { PYBIND_REGISTER(RepeatNode, 2, ([](const py::module *m) {
(void)py::class_<RepeatNode, DatasetNode, std::shared_ptr<RepeatNode>>(*m, "RepeatNode", (void)py::class_<RepeatNode, DatasetNode, std::shared_ptr<RepeatNode>>(*m, "RepeatNode",
"to create a RepeatNode") "to create a RepeatNode")
.def(py::init([](std::shared_ptr<DatasetNode> input, int32_t count) { .def(py::init([](const std::shared_ptr<DatasetNode> &input, int32_t count) {
auto repeat = std::make_shared<RepeatNode>(input, count); auto repeat = std::make_shared<RepeatNode>(input, count);
THROW_IF_ERROR(repeat->ValidateParams()); THROW_IF_ERROR(repeat->ValidateParams());
return repeat; return repeat;
@ -236,17 +238,18 @@ PYBIND_REGISTER(RepeatNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(ShuffleNode, 2, ([](const py::module *m) { PYBIND_REGISTER(ShuffleNode, 2, ([](const py::module *m) {
(void)py::class_<ShuffleNode, DatasetNode, std::shared_ptr<ShuffleNode>>(*m, "ShuffleNode", (void)py::class_<ShuffleNode, DatasetNode, std::shared_ptr<ShuffleNode>>(*m, "ShuffleNode",
"to create a ShuffleNode") "to create a ShuffleNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, int32_t shuffle_size, bool reset_every_epoch) { .def(py::init(
auto shuffle = std::make_shared<ShuffleNode>(self, shuffle_size, reset_every_epoch); [](const std::shared_ptr<DatasetNode> &self, int32_t shuffle_size, bool reset_every_epoch) {
THROW_IF_ERROR(shuffle->ValidateParams()); auto shuffle = std::make_shared<ShuffleNode>(self, shuffle_size, reset_every_epoch);
return shuffle; THROW_IF_ERROR(shuffle->ValidateParams());
})); return shuffle;
}));
})); }));
PYBIND_REGISTER(SkipNode, 2, ([](const py::module *m) { PYBIND_REGISTER(SkipNode, 2, ([](const py::module *m) {
(void)py::class_<SkipNode, DatasetNode, std::shared_ptr<SkipNode>>(*m, "SkipNode", (void)py::class_<SkipNode, DatasetNode, std::shared_ptr<SkipNode>>(*m, "SkipNode",
"to create a SkipNode") "to create a SkipNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, int32_t count) { .def(py::init([](const std::shared_ptr<DatasetNode> &self, int32_t count) {
auto skip = std::make_shared<SkipNode>(self, count); auto skip = std::make_shared<SkipNode>(self, count);
THROW_IF_ERROR(skip->ValidateParams()); THROW_IF_ERROR(skip->ValidateParams());
return skip; return skip;
@ -256,20 +259,20 @@ PYBIND_REGISTER(SkipNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(SyncWaitNode, 2, ([](const py::module *m) { PYBIND_REGISTER(SyncWaitNode, 2, ([](const py::module *m) {
(void)py::class_<SyncWaitNode, DatasetNode, std::shared_ptr<SyncWaitNode>>(*m, "SyncWaitNode", (void)py::class_<SyncWaitNode, DatasetNode, std::shared_ptr<SyncWaitNode>>(*m, "SyncWaitNode",
"to create a SyncWaitNode") "to create a SyncWaitNode")
.def( .def(py::init([](const std::shared_ptr<DatasetNode> &self, const std::string &condition_name,
py::init([](std::shared_ptr<DatasetNode> self, std::string condition_name, py::object callback) { py::object callback) {
py::function callback_func = py::function callback_func =
py::isinstance<py::function>(callback) ? callback.cast<py::function>() : py::function(); py::isinstance<py::function>(callback) ? callback.cast<py::function>() : py::function();
auto sync_wait = std::make_shared<SyncWaitNode>(self, condition_name, callback); auto sync_wait = std::make_shared<SyncWaitNode>(self, condition_name, callback);
THROW_IF_ERROR(sync_wait->ValidateParams()); THROW_IF_ERROR(sync_wait->ValidateParams());
return sync_wait; return sync_wait;
})); }));
})); }));
PYBIND_REGISTER(TakeNode, 2, ([](const py::module *m) { PYBIND_REGISTER(TakeNode, 2, ([](const py::module *m) {
(void)py::class_<TakeNode, DatasetNode, std::shared_ptr<TakeNode>>(*m, "TakeNode", (void)py::class_<TakeNode, DatasetNode, std::shared_ptr<TakeNode>>(*m, "TakeNode",
"to create a TakeNode") "to create a TakeNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, int32_t count) { .def(py::init([](const std::shared_ptr<DatasetNode> &self, int32_t count) {
auto take = std::make_shared<TakeNode>(self, count); auto take = std::make_shared<TakeNode>(self, count);
THROW_IF_ERROR(take->ValidateParams()); THROW_IF_ERROR(take->ValidateParams());
return take; return take;
@ -279,9 +282,9 @@ PYBIND_REGISTER(TakeNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(TransferNode, 2, ([](const py::module *m) { PYBIND_REGISTER(TransferNode, 2, ([](const py::module *m) {
(void)py::class_<TransferNode, DatasetNode, std::shared_ptr<TransferNode>>(*m, "TransferNode", (void)py::class_<TransferNode, DatasetNode, std::shared_ptr<TransferNode>>(*m, "TransferNode",
"to create a TransferNode") "to create a TransferNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, std::string queue_name, std::string device_type, .def(py::init([](const std::shared_ptr<DatasetNode> &self, const std::string &queue_name,
int32_t device_id, bool send_epoch_end, int32_t total_batch, const std::string &device_type, int32_t device_id, bool send_epoch_end,
bool create_data_info_queue) { int32_t total_batch, bool create_data_info_queue) {
auto transfer = std::make_shared<TransferNode>( auto transfer = std::make_shared<TransferNode>(
self, queue_name, device_type, device_id, send_epoch_end, total_batch, create_data_info_queue); self, queue_name, device_type, device_id, send_epoch_end, total_batch, create_data_info_queue);
THROW_IF_ERROR(transfer->ValidateParams()); THROW_IF_ERROR(transfer->ValidateParams());
@ -291,7 +294,7 @@ PYBIND_REGISTER(TransferNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(ZipNode, 2, ([](const py::module *m) { PYBIND_REGISTER(ZipNode, 2, ([](const py::module *m) {
(void)py::class_<ZipNode, DatasetNode, std::shared_ptr<ZipNode>>(*m, "ZipNode", "to create a ZipNode") (void)py::class_<ZipNode, DatasetNode, std::shared_ptr<ZipNode>>(*m, "ZipNode", "to create a ZipNode")
.def(py::init([](std::vector<std::shared_ptr<DatasetNode>> datasets) { .def(py::init([](const std::vector<std::shared_ptr<DatasetNode>> &datasets) {
auto zip = std::make_shared<ZipNode>(datasets); auto zip = std::make_shared<ZipNode>(datasets);
THROW_IF_ERROR(zip->ValidateParams()); THROW_IF_ERROR(zip->ValidateParams());
return zip; return zip;

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020-2021 Huawei Technologies Co., Ltd * Copyright 2020-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -17,11 +17,10 @@
#include "minddata/dataset/api/python/pybind_conversion.h" #include "minddata/dataset/api/python/pybind_conversion.h"
#include "minddata/dataset/api/python/pybind_register.h" #include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/include/dataset/datasets.h"
#include "minddata/dataset/core/config_manager.h" #include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/data_type.h" #include "minddata/dataset/core/data_type.h"
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/include/dataset/datasets.h"
#include "minddata/dataset/util/path.h" #include "minddata/dataset/util/path.h"
// IR leaf nodes // IR leaf nodes
@ -29,8 +28,8 @@
#include "minddata/dataset/engine/ir/datasetops/source/amazon_review_node.h" #include "minddata/dataset/engine/ir/datasetops/source/amazon_review_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/caltech256_node.h" #include "minddata/dataset/engine/ir/datasetops/source/caltech256_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h" #include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h" #include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/cityscapes_node.h" #include "minddata/dataset/engine/ir/datasetops/source/cityscapes_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/clue_node.h" #include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/coco_node.h" #include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
@ -81,14 +80,13 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
// PYBIND FOR LEAF NODES // PYBIND FOR LEAF NODES
// (In alphabetical order) // (In alphabetical order)
PYBIND_REGISTER(AGNewsNode, 2, ([](const py::module *m) { PYBIND_REGISTER(AGNewsNode, 2, ([](const py::module *m) {
(void)py::class_<AGNewsNode, DatasetNode, std::shared_ptr<AGNewsNode>>(*m, "AGNewsNode", (void)py::class_<AGNewsNode, DatasetNode, std::shared_ptr<AGNewsNode>>(*m, "AGNewsNode",
"to create an AGNewsNode") "to create an AGNewsNode")
.def(py::init([](std::string dataset_dir, std::string usage, int64_t num_samples, int32_t shuffle, .def(py::init([](const std::string &dataset_dir, const std::string &usage, int64_t num_samples,
int32_t num_shards, int32_t shard_id) { int32_t shuffle, int32_t num_shards, int32_t shard_id) {
auto ag_news = std::make_shared<AGNewsNode>(dataset_dir, num_samples, toShuffleMode(shuffle), auto ag_news = std::make_shared<AGNewsNode>(dataset_dir, num_samples, toShuffleMode(shuffle),
usage, num_shards, shard_id, nullptr); usage, num_shards, shard_id, nullptr);
THROW_IF_ERROR(ag_news->ValidateParams()); THROW_IF_ERROR(ag_news->ValidateParams());
@ -99,8 +97,8 @@ PYBIND_REGISTER(AGNewsNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(AmazonReviewNode, 2, ([](const py::module *m) { PYBIND_REGISTER(AmazonReviewNode, 2, ([](const py::module *m) {
(void)py::class_<AmazonReviewNode, DatasetNode, std::shared_ptr<AmazonReviewNode>>( (void)py::class_<AmazonReviewNode, DatasetNode, std::shared_ptr<AmazonReviewNode>>(
*m, "AmazonReviewNode", "to create an AmazonReviewNode") *m, "AmazonReviewNode", "to create an AmazonReviewNode")
.def(py::init([](std::string dataset_dir, std::string usage, int64_t num_samples, int32_t shuffle, .def(py::init([](const std::string &dataset_dir, const std::string &usage, int64_t num_samples,
int32_t num_shards, int32_t shard_id) { int32_t shuffle, int32_t num_shards, int32_t shard_id) {
std::shared_ptr<AmazonReviewNode> amazon_review = std::make_shared<AmazonReviewNode>( std::shared_ptr<AmazonReviewNode> amazon_review = std::make_shared<AmazonReviewNode>(
dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr); dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
THROW_IF_ERROR(amazon_review->ValidateParams()); THROW_IF_ERROR(amazon_review->ValidateParams());
@ -111,7 +109,7 @@ PYBIND_REGISTER(AmazonReviewNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(Caltech256Node, 2, ([](const py::module *m) { PYBIND_REGISTER(Caltech256Node, 2, ([](const py::module *m) {
(void)py::class_<Caltech256Node, DatasetNode, std::shared_ptr<Caltech256Node>>( (void)py::class_<Caltech256Node, DatasetNode, std::shared_ptr<Caltech256Node>>(
*m, "Caltech256Node", "to create a Caltech256Node") *m, "Caltech256Node", "to create a Caltech256Node")
.def(py::init([](std::string dataset_dir, bool decode, py::handle sampler) { .def(py::init([](const std::string &dataset_dir, bool decode, const py::handle &sampler) {
auto caltech256 = auto caltech256 =
std::make_shared<Caltech256Node>(dataset_dir, decode, toSamplerObj(sampler), nullptr); std::make_shared<Caltech256Node>(dataset_dir, decode, toSamplerObj(sampler), nullptr);
THROW_IF_ERROR(caltech256->ValidateParams()); THROW_IF_ERROR(caltech256->ValidateParams());
@ -122,8 +120,8 @@ PYBIND_REGISTER(Caltech256Node, 2, ([](const py::module *m) {
PYBIND_REGISTER(CelebANode, 2, ([](const py::module *m) { PYBIND_REGISTER(CelebANode, 2, ([](const py::module *m) {
(void)py::class_<CelebANode, DatasetNode, std::shared_ptr<CelebANode>>(*m, "CelebANode", (void)py::class_<CelebANode, DatasetNode, std::shared_ptr<CelebANode>>(*m, "CelebANode",
"to create a CelebANode") "to create a CelebANode")
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler, bool decode, .def(py::init([](const std::string &dataset_dir, const std::string &usage,
py::list extensions) { const py::handle &sampler, bool decode, const py::list &extensions) {
auto celebA = std::make_shared<CelebANode>(dataset_dir, usage, toSamplerObj(sampler), decode, auto celebA = std::make_shared<CelebANode>(dataset_dir, usage, toSamplerObj(sampler), decode,
toStringSet(extensions), nullptr); toStringSet(extensions), nullptr);
THROW_IF_ERROR(celebA->ValidateParams()); THROW_IF_ERROR(celebA->ValidateParams());
@ -134,7 +132,8 @@ PYBIND_REGISTER(CelebANode, 2, ([](const py::module *m) {
PYBIND_REGISTER(Cifar10Node, 2, ([](const py::module *m) { PYBIND_REGISTER(Cifar10Node, 2, ([](const py::module *m) {
(void)py::class_<Cifar10Node, DatasetNode, std::shared_ptr<Cifar10Node>>(*m, "Cifar10Node", (void)py::class_<Cifar10Node, DatasetNode, std::shared_ptr<Cifar10Node>>(*m, "Cifar10Node",
"to create a Cifar10Node") "to create a Cifar10Node")
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) { .def(py::init([](const std::string &dataset_dir, const std::string &usage,
const py::handle &sampler) {
auto cifar10 = std::make_shared<Cifar10Node>(dataset_dir, usage, toSamplerObj(sampler), nullptr); auto cifar10 = std::make_shared<Cifar10Node>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
THROW_IF_ERROR(cifar10->ValidateParams()); THROW_IF_ERROR(cifar10->ValidateParams());
return cifar10; return cifar10;
@ -144,19 +143,21 @@ PYBIND_REGISTER(Cifar10Node, 2, ([](const py::module *m) {
PYBIND_REGISTER(Cifar100Node, 2, ([](const py::module *m) { PYBIND_REGISTER(Cifar100Node, 2, ([](const py::module *m) {
(void)py::class_<Cifar100Node, DatasetNode, std::shared_ptr<Cifar100Node>>(*m, "Cifar100Node", (void)py::class_<Cifar100Node, DatasetNode, std::shared_ptr<Cifar100Node>>(*m, "Cifar100Node",
"to create a Cifar100Node") "to create a Cifar100Node")
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) { .def(
auto cifar100 = py::init([](const std::string &dataset_dir, const std::string &usage, const py::handle &sampler) {
std::make_shared<Cifar100Node>(dataset_dir, usage, toSamplerObj(sampler), nullptr); auto cifar100 =
THROW_IF_ERROR(cifar100->ValidateParams()); std::make_shared<Cifar100Node>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
return cifar100; THROW_IF_ERROR(cifar100->ValidateParams());
})); return cifar100;
}));
})); }));
PYBIND_REGISTER(CityscapesNode, 2, ([](const py::module *m) { PYBIND_REGISTER(CityscapesNode, 2, ([](const py::module *m) {
(void)py::class_<CityscapesNode, DatasetNode, std::shared_ptr<CityscapesNode>>( (void)py::class_<CityscapesNode, DatasetNode, std::shared_ptr<CityscapesNode>>(
*m, "CityscapesNode", "to create a CityscapesNode") *m, "CityscapesNode", "to create a CityscapesNode")
.def(py::init([](std::string dataset_dir, std::string usage, std::string quality_mode, .def(py::init([](const std::string &dataset_dir, const std::string &usage,
std::string task, bool decode, const py::handle &sampler) { const std::string &quality_mode, const std::string &task, bool decode,
const py::handle &sampler) {
auto cityscapes = std::make_shared<CityscapesNode>(dataset_dir, usage, quality_mode, task, decode, auto cityscapes = std::make_shared<CityscapesNode>(dataset_dir, usage, quality_mode, task, decode,
toSamplerObj(sampler), nullptr); toSamplerObj(sampler), nullptr);
THROW_IF_ERROR(cityscapes->ValidateParams()); THROW_IF_ERROR(cityscapes->ValidateParams());
@ -167,8 +168,8 @@ PYBIND_REGISTER(CityscapesNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(CLUENode, 2, ([](const py::module *m) { PYBIND_REGISTER(CLUENode, 2, ([](const py::module *m) {
(void)py::class_<CLUENode, DatasetNode, std::shared_ptr<CLUENode>>(*m, "CLUENode", (void)py::class_<CLUENode, DatasetNode, std::shared_ptr<CLUENode>>(*m, "CLUENode",
"to create a CLUENode") "to create a CLUENode")
.def(py::init([](py::list files, std::string task, std::string usage, int64_t num_samples, .def(py::init([](const py::list &files, const std::string &task, const std::string &usage,
int32_t shuffle, int32_t num_shards, int32_t shard_id) { int64_t num_samples, int32_t shuffle, int32_t num_shards, int32_t shard_id) {
std::shared_ptr<CLUENode> clue_node = std::shared_ptr<CLUENode> clue_node =
std::make_shared<dataset::CLUENode>(toStringVector(files), task, usage, num_samples, std::make_shared<dataset::CLUENode>(toStringVector(files), task, usage, num_samples,
toShuffleMode(shuffle), num_shards, shard_id, nullptr); toShuffleMode(shuffle), num_shards, shard_id, nullptr);
@ -180,8 +181,9 @@ PYBIND_REGISTER(CLUENode, 2, ([](const py::module *m) {
PYBIND_REGISTER(CocoNode, 2, ([](const py::module *m) { PYBIND_REGISTER(CocoNode, 2, ([](const py::module *m) {
(void)py::class_<CocoNode, DatasetNode, std::shared_ptr<CocoNode>>(*m, "CocoNode", (void)py::class_<CocoNode, DatasetNode, std::shared_ptr<CocoNode>>(*m, "CocoNode",
"to create a CocoNode") "to create a CocoNode")
.def(py::init([](std::string dataset_dir, std::string annotation_file, std::string task, .def(py::init([](const std::string &dataset_dir, const std::string &annotation_file,
bool decode, const py::handle &sampler, bool extra_metadata) { const std::string &task, bool decode, const py::handle &sampler,
bool extra_metadata) {
std::shared_ptr<CocoNode> coco = std::make_shared<CocoNode>( std::shared_ptr<CocoNode> coco = std::make_shared<CocoNode>(
dataset_dir, annotation_file, task, decode, toSamplerObj(sampler), nullptr, extra_metadata); dataset_dir, annotation_file, task, decode, toSamplerObj(sampler), nullptr, extra_metadata);
THROW_IF_ERROR(coco->ValidateParams()); THROW_IF_ERROR(coco->ValidateParams());
@ -192,8 +194,8 @@ PYBIND_REGISTER(CocoNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(CoNLL2000Node, 2, ([](const py::module *m) { PYBIND_REGISTER(CoNLL2000Node, 2, ([](const py::module *m) {
(void)py::class_<CoNLL2000Node, DatasetNode, std::shared_ptr<CoNLL2000Node>>( (void)py::class_<CoNLL2000Node, DatasetNode, std::shared_ptr<CoNLL2000Node>>(
*m, "CoNLL2000Node", "to create a CoNLL2000Node") *m, "CoNLL2000Node", "to create a CoNLL2000Node")
.def(py::init([](std::string dataset_dir, std::string usage, int64_t num_samples, int32_t shuffle, .def(py::init([](const std::string &dataset_dir, const std::string &usage, int64_t num_samples,
int32_t num_shards, int32_t shard_id) { int32_t shuffle, int32_t num_shards, int32_t shard_id) {
std::shared_ptr<CoNLL2000Node> conll2000 = std::make_shared<CoNLL2000Node>( std::shared_ptr<CoNLL2000Node> conll2000 = std::make_shared<CoNLL2000Node>(
dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr); dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
THROW_IF_ERROR(conll2000->ValidateParams()); THROW_IF_ERROR(conll2000->ValidateParams());
@ -203,9 +205,9 @@ PYBIND_REGISTER(CoNLL2000Node, 2, ([](const py::module *m) {
PYBIND_REGISTER(CSVNode, 2, ([](const py::module *m) { PYBIND_REGISTER(CSVNode, 2, ([](const py::module *m) {
(void)py::class_<CSVNode, DatasetNode, std::shared_ptr<CSVNode>>(*m, "CSVNode", "to create a CSVNode") (void)py::class_<CSVNode, DatasetNode, std::shared_ptr<CSVNode>>(*m, "CSVNode", "to create a CSVNode")
.def(py::init([](std::vector<std::string> csv_files, char field_delim, py::list column_defaults, .def(py::init([](const std::vector<std::string> &csv_files, char field_delim,
std::vector<std::string> column_names, int64_t num_samples, int32_t shuffle, const py::list &column_defaults, const std::vector<std::string> &column_names,
int32_t num_shards, int32_t shard_id) { int64_t num_samples, int32_t shuffle, int32_t num_shards, int32_t shard_id) {
auto csv = auto csv =
std::make_shared<CSVNode>(csv_files, field_delim, toCSVBase(column_defaults), column_names, std::make_shared<CSVNode>(csv_files, field_delim, toCSVBase(column_defaults), column_names,
num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr); num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
@ -217,8 +219,8 @@ PYBIND_REGISTER(CSVNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(DBpediaNode, 2, ([](const py::module *m) { PYBIND_REGISTER(DBpediaNode, 2, ([](const py::module *m) {
(void)py::class_<DBpediaNode, DatasetNode, std::shared_ptr<DBpediaNode>>(*m, "DBpediaNode", (void)py::class_<DBpediaNode, DatasetNode, std::shared_ptr<DBpediaNode>>(*m, "DBpediaNode",
"to create a DBpediaNode") "to create a DBpediaNode")
.def(py::init([](std::string dataset_dir, std::string usage, int64_t num_samples, int32_t shuffle, .def(py::init([](const std::string &dataset_dir, const std::string &usage, int64_t num_samples,
int32_t num_shards, int32_t shard_id) { int32_t shuffle, int32_t num_shards, int32_t shard_id) {
auto dbpedia = std::make_shared<DBpediaNode>( auto dbpedia = std::make_shared<DBpediaNode>(
dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr); dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
THROW_IF_ERROR(dbpedia->ValidateParams()); THROW_IF_ERROR(dbpedia->ValidateParams());
@ -229,19 +231,21 @@ PYBIND_REGISTER(DBpediaNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(DIV2KNode, 2, ([](const py::module *m) { PYBIND_REGISTER(DIV2KNode, 2, ([](const py::module *m) {
(void)py::class_<DIV2KNode, DatasetNode, std::shared_ptr<DIV2KNode>>(*m, "DIV2KNode", (void)py::class_<DIV2KNode, DatasetNode, std::shared_ptr<DIV2KNode>>(*m, "DIV2KNode",
"to create a DIV2KNode") "to create a DIV2KNode")
.def(py::init([](std::string dataset_dir, std::string usage, std::string downgrade, int32_t scale, .def(
bool decode, py::handle sampler) { py::init([](const std::string &dataset_dir, const std::string &usage,
auto div2k = std::make_shared<DIV2KNode>(dataset_dir, usage, downgrade, scale, decode, const std::string &downgrade, int32_t scale, bool decode, const py::handle &sampler) {
toSamplerObj(sampler), nullptr); auto div2k = std::make_shared<DIV2KNode>(dataset_dir, usage, downgrade, scale, decode,
THROW_IF_ERROR(div2k->ValidateParams()); toSamplerObj(sampler), nullptr);
return div2k; THROW_IF_ERROR(div2k->ValidateParams());
})); return div2k;
}));
})); }));
PYBIND_REGISTER(EMnistNode, 2, ([](const py::module *m) { PYBIND_REGISTER(EMnistNode, 2, ([](const py::module *m) {
(void)py::class_<EMnistNode, DatasetNode, std::shared_ptr<EMnistNode>>(*m, "EMnistNode", (void)py::class_<EMnistNode, DatasetNode, std::shared_ptr<EMnistNode>>(*m, "EMnistNode",
"to create an EMnistNode") "to create an EMnistNode")
.def(py::init([](std::string dataset_dir, std::string name, std::string usage, py::handle sampler) { .def(py::init([](const std::string &dataset_dir, const std::string &name, const std::string &usage,
const py::handle &sampler) {
auto emnist = auto emnist =
std::make_shared<EMnistNode>(dataset_dir, name, usage, toSamplerObj(sampler), nullptr); std::make_shared<EMnistNode>(dataset_dir, name, usage, toSamplerObj(sampler), nullptr);
THROW_IF_ERROR(emnist->ValidateParams()); THROW_IF_ERROR(emnist->ValidateParams());
@ -252,8 +256,8 @@ PYBIND_REGISTER(EMnistNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(EnWik9Node, 2, ([](const py::module *m) { PYBIND_REGISTER(EnWik9Node, 2, ([](const py::module *m) {
(void)py::class_<EnWik9Node, DatasetNode, std::shared_ptr<EnWik9Node>>(*m, "EnWik9Node", (void)py::class_<EnWik9Node, DatasetNode, std::shared_ptr<EnWik9Node>>(*m, "EnWik9Node",
"to create an EnWik9Node") "to create an EnWik9Node")
.def(py::init([](std::string dataset_dir, int32_t num_samples, int32_t shuffle, int32_t num_shards, .def(py::init([](const std::string &dataset_dir, int32_t num_samples, int32_t shuffle,
int32_t shard_id) { int32_t num_shards, int32_t shard_id) {
std::shared_ptr<EnWik9Node> en_wik9 = std::make_shared<EnWik9Node>( std::shared_ptr<EnWik9Node> en_wik9 = std::make_shared<EnWik9Node>(
dataset_dir, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr); dataset_dir, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
THROW_IF_ERROR(en_wik9->ValidateParams()); THROW_IF_ERROR(en_wik9->ValidateParams());
@ -264,8 +268,8 @@ PYBIND_REGISTER(EnWik9Node, 2, ([](const py::module *m) {
PYBIND_REGISTER(FakeImageNode, 2, ([](const py::module *m) { PYBIND_REGISTER(FakeImageNode, 2, ([](const py::module *m) {
(void)py::class_<FakeImageNode, DatasetNode, std::shared_ptr<FakeImageNode>>( (void)py::class_<FakeImageNode, DatasetNode, std::shared_ptr<FakeImageNode>>(
*m, "FakeImageNode", "to create a FakeImageNode") *m, "FakeImageNode", "to create a FakeImageNode")
.def(py::init([](int32_t num_images, const std::vector<int32_t> image_size, int32_t num_classes, .def(py::init([](int32_t num_images, const std::vector<int32_t> &image_size, int32_t num_classes,
int32_t base_seed, py::handle sampler) { int32_t base_seed, const py::handle &sampler) {
auto fake_image = std::make_shared<FakeImageNode>(num_images, image_size, num_classes, base_seed, auto fake_image = std::make_shared<FakeImageNode>(num_images, image_size, num_classes, base_seed,
toSamplerObj(sampler), nullptr); toSamplerObj(sampler), nullptr);
THROW_IF_ERROR(fake_image->ValidateParams()); THROW_IF_ERROR(fake_image->ValidateParams());
@ -276,39 +280,42 @@ PYBIND_REGISTER(FakeImageNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(FashionMnistNode, 2, ([](const py::module *m) { PYBIND_REGISTER(FashionMnistNode, 2, ([](const py::module *m) {
(void)py::class_<FashionMnistNode, DatasetNode, std::shared_ptr<FashionMnistNode>>( (void)py::class_<FashionMnistNode, DatasetNode, std::shared_ptr<FashionMnistNode>>(
*m, "FashionMnistNode", "to create a FashionMnistNode") *m, "FashionMnistNode", "to create a FashionMnistNode")
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) { .def(
auto fashion_mnist = py::init([](const std::string &dataset_dir, const std::string &usage, const py::handle &sampler) {
std::make_shared<FashionMnistNode>(dataset_dir, usage, toSamplerObj(sampler), nullptr); auto fashion_mnist =
THROW_IF_ERROR(fashion_mnist->ValidateParams()); std::make_shared<FashionMnistNode>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
return fashion_mnist; THROW_IF_ERROR(fashion_mnist->ValidateParams());
})); return fashion_mnist;
}));
})); }));
PYBIND_REGISTER( PYBIND_REGISTER(FlickrNode, 2, ([](const py::module *m) {
FlickrNode, 2, ([](const py::module *m) { (void)py::class_<FlickrNode, DatasetNode, std::shared_ptr<FlickrNode>>(*m, "FlickrNode",
(void)py::class_<FlickrNode, DatasetNode, std::shared_ptr<FlickrNode>>(*m, "FlickrNode", "to create a FlickrNode") "to create a FlickrNode")
.def(py::init([](std::string dataset_dir, std::string annotation_file, bool decode, const py::handle &sampler) { .def(py::init([](const std::string &dataset_dir, const std::string &annotation_file, bool decode,
auto flickr = const py::handle &sampler) {
std::make_shared<FlickrNode>(dataset_dir, annotation_file, decode, toSamplerObj(sampler), nullptr); auto flickr = std::make_shared<FlickrNode>(dataset_dir, annotation_file, decode,
THROW_IF_ERROR(flickr->ValidateParams()); toSamplerObj(sampler), nullptr);
return flickr; THROW_IF_ERROR(flickr->ValidateParams());
})); return flickr;
})); }));
}));
PYBIND_REGISTER(GeneratorNode, 2, ([](const py::module *m) { PYBIND_REGISTER(GeneratorNode, 2, ([](const py::module *m) {
(void)py::class_<GeneratorNode, DatasetNode, std::shared_ptr<GeneratorNode>>( (void)py::class_<GeneratorNode, DatasetNode, std::shared_ptr<GeneratorNode>>(
*m, "GeneratorNode", "to create a GeneratorNode") *m, "GeneratorNode", "to create a GeneratorNode")
.def(py::init([](py::function generator_function, const std::vector<std::string> &column_names, .def(
const std::vector<DataType> &column_types, int64_t dataset_len, py::handle sampler, py::init([](const py::function &generator_function, const std::vector<std::string> &column_names,
uint32_t num_parallel_workers) { const std::vector<DataType> &column_types, int64_t dataset_len,
auto gen = const py::handle &sampler, uint32_t num_parallel_workers) {
std::make_shared<GeneratorNode>(generator_function, column_names, column_types, dataset_len, auto gen =
toSamplerObj(sampler), num_parallel_workers); std::make_shared<GeneratorNode>(generator_function, column_names, column_types, dataset_len,
THROW_IF_ERROR(gen->ValidateParams()); toSamplerObj(sampler), num_parallel_workers);
return gen; THROW_IF_ERROR(gen->ValidateParams());
})) return gen;
.def(py::init([](py::function generator_function, const std::shared_ptr<SchemaObj> schema, }))
int64_t dataset_len, py::handle sampler, uint32_t num_parallel_workers) { .def(py::init([](const py::function &generator_function, const std::shared_ptr<SchemaObj> &schema,
int64_t dataset_len, const py::handle &sampler, uint32_t num_parallel_workers) {
auto gen = std::make_shared<GeneratorNode>(generator_function, schema, dataset_len, auto gen = std::make_shared<GeneratorNode>(generator_function, schema, dataset_len,
toSamplerObj(sampler), num_parallel_workers); toSamplerObj(sampler), num_parallel_workers);
THROW_IF_ERROR(gen->ValidateParams()); THROW_IF_ERROR(gen->ValidateParams());
@ -319,8 +326,8 @@ PYBIND_REGISTER(GeneratorNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(ImageFolderNode, 2, ([](const py::module *m) { PYBIND_REGISTER(ImageFolderNode, 2, ([](const py::module *m) {
(void)py::class_<ImageFolderNode, DatasetNode, std::shared_ptr<ImageFolderNode>>( (void)py::class_<ImageFolderNode, DatasetNode, std::shared_ptr<ImageFolderNode>>(
*m, "ImageFolderNode", "to create an ImageFolderNode") *m, "ImageFolderNode", "to create an ImageFolderNode")
.def(py::init([](std::string dataset_dir, bool decode, py::handle sampler, py::list extensions, .def(py::init([](const std::string &dataset_dir, bool decode, const py::handle &sampler,
py::dict class_indexing) { const py::list &extensions, const py::dict &class_indexing) {
// Don't update recursive to true // Don't update recursive to true
bool recursive = false; // Will be removed in future PR bool recursive = false; // Will be removed in future PR
auto imagefolder = std::make_shared<ImageFolderNode>(dataset_dir, decode, toSamplerObj(sampler), auto imagefolder = std::make_shared<ImageFolderNode>(dataset_dir, decode, toSamplerObj(sampler),
@ -334,18 +341,20 @@ PYBIND_REGISTER(ImageFolderNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(IMDBNode, 2, ([](const py::module *m) { PYBIND_REGISTER(IMDBNode, 2, ([](const py::module *m) {
(void)py::class_<IMDBNode, DatasetNode, std::shared_ptr<IMDBNode>>(*m, "IMDBNode", (void)py::class_<IMDBNode, DatasetNode, std::shared_ptr<IMDBNode>>(*m, "IMDBNode",
"to create an IMDBNode") "to create an IMDBNode")
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) { .def(
auto imdb = std::make_shared<IMDBNode>(dataset_dir, usage, toSamplerObj(sampler), nullptr); py::init([](const std::string &dataset_dir, const std::string &usage, const py::handle &sampler) {
THROW_IF_ERROR(imdb->ValidateParams()); auto imdb = std::make_shared<IMDBNode>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
return imdb; THROW_IF_ERROR(imdb->ValidateParams());
})); return imdb;
}));
})); }));
PYBIND_REGISTER(IWSLT2016Node, 2, ([](const py::module *m) { PYBIND_REGISTER(IWSLT2016Node, 2, ([](const py::module *m) {
(void)py::class_<IWSLT2016Node, DatasetNode, std::shared_ptr<IWSLT2016Node>>( (void)py::class_<IWSLT2016Node, DatasetNode, std::shared_ptr<IWSLT2016Node>>(
*m, "IWSLT2016Node", "to create an IWSLT2016Node") *m, "IWSLT2016Node", "to create an IWSLT2016Node")
.def(py::init([](std::string dataset_dir, std::string usage, std::vector<std::string> language_pair, .def(py::init([](const std::string &dataset_dir, const std::string &usage,
std::string valid_set, std::string test_set, int64_t num_samples, int32_t shuffle, const std::vector<std::string> &language_pair, const std::string &valid_set,
const std::string &test_set, int64_t num_samples, int32_t shuffle,
int32_t num_shards, int32_t shard_id) { int32_t num_shards, int32_t shard_id) {
std::shared_ptr<IWSLT2016Node> iwslt2016 = std::make_shared<IWSLT2016Node>( std::shared_ptr<IWSLT2016Node> iwslt2016 = std::make_shared<IWSLT2016Node>(
dataset_dir, usage, language_pair, valid_set, test_set, num_samples, toShuffleMode(shuffle), dataset_dir, usage, language_pair, valid_set, test_set, num_samples, toShuffleMode(shuffle),
@ -358,8 +367,9 @@ PYBIND_REGISTER(IWSLT2016Node, 2, ([](const py::module *m) {
PYBIND_REGISTER(IWSLT2017Node, 2, ([](const py::module *m) { PYBIND_REGISTER(IWSLT2017Node, 2, ([](const py::module *m) {
(void)py::class_<IWSLT2017Node, DatasetNode, std::shared_ptr<IWSLT2017Node>>( (void)py::class_<IWSLT2017Node, DatasetNode, std::shared_ptr<IWSLT2017Node>>(
*m, "IWSLT2017Node", "to create an IWSLT2017Node") *m, "IWSLT2017Node", "to create an IWSLT2017Node")
.def(py::init([](std::string dataset_dir, std::string usage, std::vector<std::string> language_pair, .def(py::init([](const std::string &dataset_dir, const std::string &usage,
int64_t num_samples, int32_t shuffle, int32_t num_shards, int32_t shard_id) { const std::vector<std::string> &language_pair, int64_t num_samples,
int32_t shuffle, int32_t num_shards, int32_t shard_id) {
std::shared_ptr<IWSLT2017Node> iwslt2017 = std::shared_ptr<IWSLT2017Node> iwslt2017 =
std::make_shared<IWSLT2017Node>(dataset_dir, usage, language_pair, num_samples, std::make_shared<IWSLT2017Node>(dataset_dir, usage, language_pair, num_samples,
toShuffleMode(shuffle), num_shards, shard_id, nullptr); toShuffleMode(shuffle), num_shards, shard_id, nullptr);
@ -371,17 +381,18 @@ PYBIND_REGISTER(IWSLT2017Node, 2, ([](const py::module *m) {
PYBIND_REGISTER(KMnistNode, 2, ([](const py::module *m) { PYBIND_REGISTER(KMnistNode, 2, ([](const py::module *m) {
(void)py::class_<KMnistNode, DatasetNode, std::shared_ptr<KMnistNode>>(*m, "KMnistNode", (void)py::class_<KMnistNode, DatasetNode, std::shared_ptr<KMnistNode>>(*m, "KMnistNode",
"to create a KMnistNode") "to create a KMnistNode")
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) { .def(
auto kmnist = std::make_shared<KMnistNode>(dataset_dir, usage, toSamplerObj(sampler), nullptr); py::init([](const std::string &dataset_dir, const std::string &usage, const py::handle &sampler) {
THROW_IF_ERROR(kmnist->ValidateParams()); auto kmnist = std::make_shared<KMnistNode>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
return kmnist; THROW_IF_ERROR(kmnist->ValidateParams());
})); return kmnist;
}));
})); }));
PYBIND_REGISTER(LJSpeechNode, 2, ([](const py::module *m) { PYBIND_REGISTER(LJSpeechNode, 2, ([](const py::module *m) {
(void)py::class_<LJSpeechNode, DatasetNode, std::shared_ptr<LJSpeechNode>>(*m, "LJSpeechNode", (void)py::class_<LJSpeechNode, DatasetNode, std::shared_ptr<LJSpeechNode>>(*m, "LJSpeechNode",
"to create a LJSpeechNode") "to create a LJSpeechNode")
.def(py::init([](std::string dataset_dir, py::handle sampler) { .def(py::init([](const std::string &dataset_dir, const py::handle &sampler) {
auto lj_speech = std::make_shared<LJSpeechNode>(dataset_dir, toSamplerObj(sampler), nullptr); auto lj_speech = std::make_shared<LJSpeechNode>(dataset_dir, toSamplerObj(sampler), nullptr);
THROW_IF_ERROR(lj_speech->ValidateParams()); THROW_IF_ERROR(lj_speech->ValidateParams());
return lj_speech; return lj_speech;
@ -391,8 +402,8 @@ PYBIND_REGISTER(LJSpeechNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(ManifestNode, 2, ([](const py::module *m) { PYBIND_REGISTER(ManifestNode, 2, ([](const py::module *m) {
(void)py::class_<ManifestNode, DatasetNode, std::shared_ptr<ManifestNode>>(*m, "ManifestNode", (void)py::class_<ManifestNode, DatasetNode, std::shared_ptr<ManifestNode>>(*m, "ManifestNode",
"to create a ManifestNode") "to create a ManifestNode")
.def(py::init([](std::string dataset_file, std::string usage, py::handle sampler, .def(py::init([](const std::string &dataset_file, const std::string &usage,
py::dict class_indexing, bool decode) { const py::handle &sampler, const py::dict &class_indexing, bool decode) {
auto manifest = std::make_shared<ManifestNode>(dataset_file, usage, toSamplerObj(sampler), auto manifest = std::make_shared<ManifestNode>(dataset_file, usage, toSamplerObj(sampler),
toStringMap(class_indexing), decode, nullptr); toStringMap(class_indexing), decode, nullptr);
THROW_IF_ERROR(manifest->ValidateParams()); THROW_IF_ERROR(manifest->ValidateParams());
@ -400,50 +411,52 @@ PYBIND_REGISTER(ManifestNode, 2, ([](const py::module *m) {
})); }));
})); }));
PYBIND_REGISTER(MindDataNode, 2, ([](const py::module *m) { PYBIND_REGISTER(
(void)py::class_<MindDataNode, DatasetNode, std::shared_ptr<MindDataNode>>(*m, "MindDataNode", MindDataNode, 2, ([](const py::module *m) {
"to create a MindDataNode") (void)py::class_<MindDataNode, DatasetNode, std::shared_ptr<MindDataNode>>(*m, "MindDataNode",
.def(py::init([](std::string dataset_file, py::list columns_list, py::handle sampler, "to create a MindDataNode")
const py::dict &padded_sample, int64_t num_padded, ShuffleMode shuffle_mode) { .def(py::init([](const std::string &dataset_file, const py::list &columns_list, const py::handle &sampler,
nlohmann::json padded_sample_json; const py::dict &padded_sample, int64_t num_padded, ShuffleMode shuffle_mode) {
std::map<std::string, std::string> sample_bytes; nlohmann::json padded_sample_json;
THROW_IF_ERROR(ToJson(padded_sample, &padded_sample_json, &sample_bytes)); std::map<std::string, std::string> sample_bytes;
auto minddata = std::make_shared<MindDataNode>(dataset_file, toStringVector(columns_list), THROW_IF_ERROR(ToJson(padded_sample, &padded_sample_json, &sample_bytes));
toSamplerObj(sampler, true), padded_sample_json, auto minddata =
num_padded, shuffle_mode, nullptr); std::make_shared<MindDataNode>(dataset_file, toStringVector(columns_list), toSamplerObj(sampler, true),
minddata->SetSampleBytes(&sample_bytes); padded_sample_json, num_padded, shuffle_mode, nullptr);
THROW_IF_ERROR(minddata->ValidateParams()); minddata->SetSampleBytes(&sample_bytes);
return minddata; THROW_IF_ERROR(minddata->ValidateParams());
})) return minddata;
.def(py::init([](py::list dataset_file, py::list columns_list, py::handle sampler, }))
const py::dict &padded_sample, int64_t num_padded, ShuffleMode shuffle_mode) { .def(py::init([](const py::list &dataset_file, const py::list &columns_list, const py::handle &sampler,
nlohmann::json padded_sample_json; const py::dict &padded_sample, int64_t num_padded, ShuffleMode shuffle_mode) {
std::map<std::string, std::string> sample_bytes; nlohmann::json padded_sample_json;
THROW_IF_ERROR(ToJson(padded_sample, &padded_sample_json, &sample_bytes)); std::map<std::string, std::string> sample_bytes;
auto minddata = std::make_shared<MindDataNode>( THROW_IF_ERROR(ToJson(padded_sample, &padded_sample_json, &sample_bytes));
toStringVector(dataset_file), toStringVector(columns_list), toSamplerObj(sampler, true), auto minddata = std::make_shared<MindDataNode>(toStringVector(dataset_file), toStringVector(columns_list),
padded_sample_json, num_padded, shuffle_mode, nullptr); toSamplerObj(sampler, true), padded_sample_json, num_padded,
minddata->SetSampleBytes(&sample_bytes); shuffle_mode, nullptr);
THROW_IF_ERROR(minddata->ValidateParams()); minddata->SetSampleBytes(&sample_bytes);
return minddata; THROW_IF_ERROR(minddata->ValidateParams());
})); return minddata;
})); }));
}));
PYBIND_REGISTER(MnistNode, 2, ([](const py::module *m) { PYBIND_REGISTER(MnistNode, 2, ([](const py::module *m) {
(void)py::class_<MnistNode, DatasetNode, std::shared_ptr<MnistNode>>(*m, "MnistNode", (void)py::class_<MnistNode, DatasetNode, std::shared_ptr<MnistNode>>(*m, "MnistNode",
"to create an MnistNode") "to create an MnistNode")
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) { .def(
auto mnist = std::make_shared<MnistNode>(dataset_dir, usage, toSamplerObj(sampler), nullptr); py::init([](const std::string &dataset_dir, const std::string &usage, const py::handle &sampler) {
THROW_IF_ERROR(mnist->ValidateParams()); auto mnist = std::make_shared<MnistNode>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
return mnist; THROW_IF_ERROR(mnist->ValidateParams());
})); return mnist;
}));
})); }));
PYBIND_REGISTER(PennTreebankNode, 2, ([](const py::module *m) { PYBIND_REGISTER(PennTreebankNode, 2, ([](const py::module *m) {
(void)py::class_<PennTreebankNode, DatasetNode, std::shared_ptr<PennTreebankNode>>( (void)py::class_<PennTreebankNode, DatasetNode, std::shared_ptr<PennTreebankNode>>(
*m, "PennTreebankNode", "to create a PennTreebankNode") *m, "PennTreebankNode", "to create a PennTreebankNode")
.def(py::init([](std::string dataset_dir, std::string usage, int32_t num_samples, int32_t shuffle, .def(py::init([](const std::string &dataset_dir, const std::string &usage, int32_t num_samples,
int32_t num_shards, int32_t shard_id) { int32_t shuffle, int32_t num_shards, int32_t shard_id) {
auto penn_treebank = std::make_shared<PennTreebankNode>( auto penn_treebank = std::make_shared<PennTreebankNode>(
dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr); dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
THROW_IF_ERROR(penn_treebank->ValidateParams()); THROW_IF_ERROR(penn_treebank->ValidateParams());
@ -454,7 +467,8 @@ PYBIND_REGISTER(PennTreebankNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(PhotoTourNode, 2, ([](const py::module *m) { PYBIND_REGISTER(PhotoTourNode, 2, ([](const py::module *m) {
(void)py::class_<PhotoTourNode, DatasetNode, std::shared_ptr<PhotoTourNode>>( (void)py::class_<PhotoTourNode, DatasetNode, std::shared_ptr<PhotoTourNode>>(
*m, "PhotoTourNode", "to create a PhotoTourNode") *m, "PhotoTourNode", "to create a PhotoTourNode")
.def(py::init([](std::string dataset_dir, std::string name, std::string usage, py::handle sampler) { .def(py::init([](const std::string &dataset_dir, const std::string &name, const std::string &usage,
const py::handle &sampler) {
auto photo_tour = auto photo_tour =
std::make_shared<PhotoTourNode>(dataset_dir, name, usage, toSamplerObj(sampler), nullptr); std::make_shared<PhotoTourNode>(dataset_dir, name, usage, toSamplerObj(sampler), nullptr);
THROW_IF_ERROR(photo_tour->ValidateParams()); THROW_IF_ERROR(photo_tour->ValidateParams());
@ -465,19 +479,20 @@ PYBIND_REGISTER(PhotoTourNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(Places365Node, 2, ([](const py::module *m) { PYBIND_REGISTER(Places365Node, 2, ([](const py::module *m) {
(void)py::class_<Places365Node, DatasetNode, std::shared_ptr<Places365Node>>( (void)py::class_<Places365Node, DatasetNode, std::shared_ptr<Places365Node>>(
*m, "Places365Node", "to create a Places365Node") *m, "Places365Node", "to create a Places365Node")
.def(py::init( .def(py::init([](const std::string &dataset_dir, const std::string &usage, bool small, bool decode,
[](std::string dataset_dir, std::string usage, bool small, bool decode, py::handle sampler) { const py::handle &sampler) {
auto places365 = std::make_shared<Places365Node>(dataset_dir, usage, small, decode, auto places365 = std::make_shared<Places365Node>(dataset_dir, usage, small, decode,
toSamplerObj(sampler), nullptr); toSamplerObj(sampler), nullptr);
THROW_IF_ERROR(places365->ValidateParams()); THROW_IF_ERROR(places365->ValidateParams());
return places365; return places365;
})); }));
})); }));
PYBIND_REGISTER(QMnistNode, 2, ([](const py::module *m) { PYBIND_REGISTER(QMnistNode, 2, ([](const py::module *m) {
(void)py::class_<QMnistNode, DatasetNode, std::shared_ptr<QMnistNode>>(*m, "QMnistNode", (void)py::class_<QMnistNode, DatasetNode, std::shared_ptr<QMnistNode>>(*m, "QMnistNode",
"to create a QMnistNode") "to create a QMnistNode")
.def(py::init([](std::string dataset_dir, std::string usage, bool compat, py::handle sampler) { .def(py::init([](const std::string &dataset_dir, const std::string &usage, bool compat,
const py::handle &sampler) {
auto qmnist = auto qmnist =
std::make_shared<QMnistNode>(dataset_dir, usage, compat, toSamplerObj(sampler), nullptr); std::make_shared<QMnistNode>(dataset_dir, usage, compat, toSamplerObj(sampler), nullptr);
THROW_IF_ERROR(qmnist->ValidateParams()); THROW_IF_ERROR(qmnist->ValidateParams());
@ -485,27 +500,25 @@ PYBIND_REGISTER(QMnistNode, 2, ([](const py::module *m) {
})); }));
})); }));
PYBIND_REGISTER(RandomNode, 2, ([](const py::module *m) { PYBIND_REGISTER(
(void)py::class_<RandomNode, DatasetNode, std::shared_ptr<RandomNode>>(*m, "RandomNode", RandomNode, 2, ([](const py::module *m) {
"to create a RandomNode") (void)py::class_<RandomNode, DatasetNode, std::shared_ptr<RandomNode>>(*m, "RandomNode", "to create a RandomNode")
.def(py::init([](int32_t total_rows, std::shared_ptr<SchemaObj> schema, py::list columns_list) { .def(py::init([](int32_t total_rows, const std::shared_ptr<SchemaObj> &schema, const py::list &columns_list) {
auto random_node = auto random_node = std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), nullptr);
std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), nullptr); THROW_IF_ERROR(random_node->ValidateParams());
THROW_IF_ERROR(random_node->ValidateParams()); return random_node;
return random_node; }))
})) .def(py::init([](int32_t total_rows, const std::string &schema, const py::list &columns_list) {
.def(py::init([](int32_t total_rows, std::string schema, py::list columns_list) { auto random_node = std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), nullptr);
auto random_node = THROW_IF_ERROR(random_node->ValidateParams());
std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), nullptr); return random_node;
THROW_IF_ERROR(random_node->ValidateParams()); }));
return random_node; }));
}));
}));
PYBIND_REGISTER(SBUNode, 2, ([](const py::module *m) { PYBIND_REGISTER(SBUNode, 2, ([](const py::module *m) {
(void)py::class_<SBUNode, DatasetNode, std::shared_ptr<SBUNode>>(*m, "SBUNode", (void)py::class_<SBUNode, DatasetNode, std::shared_ptr<SBUNode>>(*m, "SBUNode",
"to create an SBUNode") "to create an SBUNode")
.def(py::init([](std::string dataset_dir, bool decode, const py::handle &sampler) { .def(py::init([](const std::string &dataset_dir, bool decode, const py::handle &sampler) {
auto sbu = std::make_shared<SBUNode>(dataset_dir, decode, toSamplerObj(sampler), nullptr); auto sbu = std::make_shared<SBUNode>(dataset_dir, decode, toSamplerObj(sampler), nullptr);
THROW_IF_ERROR(sbu->ValidateParams()); THROW_IF_ERROR(sbu->ValidateParams());
return sbu; return sbu;
@ -515,7 +528,7 @@ PYBIND_REGISTER(SBUNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(SemeionNode, 2, ([](const py::module *m) { PYBIND_REGISTER(SemeionNode, 2, ([](const py::module *m) {
(void)py::class_<SemeionNode, DatasetNode, std::shared_ptr<SemeionNode>>(*m, "SemeionNode", (void)py::class_<SemeionNode, DatasetNode, std::shared_ptr<SemeionNode>>(*m, "SemeionNode",
"to create a SemeionNode") "to create a SemeionNode")
.def(py::init([](std::string dataset_dir, py::handle sampler) { .def(py::init([](const std::string &dataset_dir, const py::handle &sampler) {
auto semeion = std::make_shared<SemeionNode>(dataset_dir, toSamplerObj(sampler), nullptr); auto semeion = std::make_shared<SemeionNode>(dataset_dir, toSamplerObj(sampler), nullptr);
THROW_IF_ERROR(semeion->ValidateParams()); THROW_IF_ERROR(semeion->ValidateParams());
return semeion; return semeion;
@ -525,8 +538,8 @@ PYBIND_REGISTER(SemeionNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(SogouNewsNode, 2, ([](const py::module *m) { PYBIND_REGISTER(SogouNewsNode, 2, ([](const py::module *m) {
(void)py::class_<SogouNewsNode, DatasetNode, std::shared_ptr<SogouNewsNode>>( (void)py::class_<SogouNewsNode, DatasetNode, std::shared_ptr<SogouNewsNode>>(
*m, "SogouNewsNode", "to create a SogouNewsNode") *m, "SogouNewsNode", "to create a SogouNewsNode")
.def(py::init([](std::string dataset_dir, std::string usage, int64_t num_samples, int32_t shuffle, .def(py::init([](const std::string &dataset_dir, const std::string &usage, int64_t num_samples,
int32_t num_shards, int32_t shard_id) { int32_t shuffle, int32_t num_shards, int32_t shard_id) {
auto sogou_news = std::make_shared<SogouNewsNode>( auto sogou_news = std::make_shared<SogouNewsNode>(
dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr); dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
THROW_IF_ERROR(sogou_news->ValidateParams()); THROW_IF_ERROR(sogou_news->ValidateParams());
@ -537,41 +550,44 @@ PYBIND_REGISTER(SogouNewsNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(SpeechCommandsNode, 2, ([](const py::module *m) { PYBIND_REGISTER(SpeechCommandsNode, 2, ([](const py::module *m) {
(void)py::class_<SpeechCommandsNode, DatasetNode, std::shared_ptr<SpeechCommandsNode>>( (void)py::class_<SpeechCommandsNode, DatasetNode, std::shared_ptr<SpeechCommandsNode>>(
*m, "SpeechCommandsNode", "to create a SpeechCommandsNode") *m, "SpeechCommandsNode", "to create a SpeechCommandsNode")
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) { .def(
auto speech_commands = py::init([](const std::string &dataset_dir, const std::string &usage, const py::handle &sampler) {
std::make_shared<SpeechCommandsNode>(dataset_dir, usage, toSamplerObj(sampler), nullptr); auto speech_commands =
THROW_IF_ERROR(speech_commands->ValidateParams()); std::make_shared<SpeechCommandsNode>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
return speech_commands; THROW_IF_ERROR(speech_commands->ValidateParams());
})); return speech_commands;
}));
})); }));
PYBIND_REGISTER(STL10Node, 2, ([](const py::module *m) { PYBIND_REGISTER(STL10Node, 2, ([](const py::module *m) {
(void)py::class_<STL10Node, DatasetNode, std::shared_ptr<STL10Node>>(*m, "STL10Node", (void)py::class_<STL10Node, DatasetNode, std::shared_ptr<STL10Node>>(*m, "STL10Node",
"to create a STL10Node") "to create a STL10Node")
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) { .def(
auto stl10 = std::make_shared<STL10Node>(dataset_dir, usage, toSamplerObj(sampler), nullptr); py::init([](const std::string &dataset_dir, const std::string &usage, const py::handle &sampler) {
THROW_IF_ERROR(stl10->ValidateParams()); auto stl10 = std::make_shared<STL10Node>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
return stl10; THROW_IF_ERROR(stl10->ValidateParams());
})); return stl10;
}));
})); }));
PYBIND_REGISTER(TedliumNode, 2, ([](const py::module *m) { PYBIND_REGISTER(TedliumNode, 2, ([](const py::module *m) {
(void)py::class_<TedliumNode, DatasetNode, std::shared_ptr<TedliumNode>>(*m, "TedliumNode", (void)py::class_<TedliumNode, DatasetNode, std::shared_ptr<TedliumNode>>(*m, "TedliumNode",
"to create a TedliumNode") "to create a TedliumNode")
.def(py::init([](std::string dataset_dir, std::string release, std::string usage, .def(
std::string extensions, py::handle sampler) { py::init([](const std::string &dataset_dir, const std::string &release, const std::string &usage,
auto tedlium = std::make_shared<TedliumNode>(dataset_dir, release, usage, extensions, const std::string &extensions, const py::handle &sampler) {
toSamplerObj(sampler), nullptr); auto tedlium = std::make_shared<TedliumNode>(dataset_dir, release, usage, extensions,
THROW_IF_ERROR(tedlium->ValidateParams()); toSamplerObj(sampler), nullptr);
return tedlium; THROW_IF_ERROR(tedlium->ValidateParams());
})); return tedlium;
}));
})); }));
PYBIND_REGISTER(TextFileNode, 2, ([](const py::module *m) { PYBIND_REGISTER(TextFileNode, 2, ([](const py::module *m) {
(void)py::class_<TextFileNode, DatasetNode, std::shared_ptr<TextFileNode>>(*m, "TextFileNode", (void)py::class_<TextFileNode, DatasetNode, std::shared_ptr<TextFileNode>>(*m, "TextFileNode",
"to create a TextFileNode") "to create a TextFileNode")
.def(py::init([](py::list dataset_files, int32_t num_samples, int32_t shuffle, int32_t num_shards, .def(py::init([](const py::list &dataset_files, int32_t num_samples, int32_t shuffle,
int32_t shard_id) { int32_t num_shards, int32_t shard_id) {
std::shared_ptr<TextFileNode> textfile_node = std::shared_ptr<TextFileNode> textfile_node =
std::make_shared<TextFileNode>(toStringVector(dataset_files), num_samples, std::make_shared<TextFileNode>(toStringVector(dataset_files), num_samples,
toShuffleMode(shuffle), num_shards, shard_id, nullptr); toShuffleMode(shuffle), num_shards, shard_id, nullptr);
@ -583,8 +599,8 @@ PYBIND_REGISTER(TextFileNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(TFRecordNode, 2, ([](const py::module *m) { PYBIND_REGISTER(TFRecordNode, 2, ([](const py::module *m) {
(void)py::class_<TFRecordNode, DatasetNode, std::shared_ptr<TFRecordNode>>(*m, "TFRecordNode", (void)py::class_<TFRecordNode, DatasetNode, std::shared_ptr<TFRecordNode>>(*m, "TFRecordNode",
"to create a TFRecordNode") "to create a TFRecordNode")
.def(py::init([](const py::list dataset_files, std::shared_ptr<SchemaObj> schema, .def(py::init([](const py::list &dataset_files, const std::shared_ptr<SchemaObj> &schema,
const py::list columns_list, int64_t num_samples, int32_t shuffle, const py::list &columns_list, int64_t num_samples, int32_t shuffle,
int32_t num_shards, int32_t shard_id, bool shard_equal_rows) { int32_t num_shards, int32_t shard_id, bool shard_equal_rows) {
std::shared_ptr<TFRecordNode> tfrecord = std::make_shared<TFRecordNode>( std::shared_ptr<TFRecordNode> tfrecord = std::make_shared<TFRecordNode>(
toStringVector(dataset_files), schema, toStringVector(columns_list), num_samples, toStringVector(dataset_files), schema, toStringVector(columns_list), num_samples,
@ -592,9 +608,9 @@ PYBIND_REGISTER(TFRecordNode, 2, ([](const py::module *m) {
THROW_IF_ERROR(tfrecord->ValidateParams()); THROW_IF_ERROR(tfrecord->ValidateParams());
return tfrecord; return tfrecord;
})) }))
.def(py::init([](const py::list dataset_files, std::string schema, const py::list columns_list, .def(py::init([](const py::list &dataset_files, const std::string &schema,
int64_t num_samples, int32_t shuffle, int32_t num_shards, int32_t shard_id, const py::list &columns_list, int64_t num_samples, int32_t shuffle,
bool shard_equal_rows) { int32_t num_shards, int32_t shard_id, bool shard_equal_rows) {
std::shared_ptr<TFRecordNode> tfrecord = std::make_shared<TFRecordNode>( std::shared_ptr<TFRecordNode> tfrecord = std::make_shared<TFRecordNode>(
toStringVector(dataset_files), schema, toStringVector(columns_list), num_samples, toStringVector(dataset_files), schema, toStringVector(columns_list), num_samples,
toShuffleMode(shuffle), num_shards, shard_id, shard_equal_rows, nullptr); toShuffleMode(shuffle), num_shards, shard_id, shard_equal_rows, nullptr);
@ -606,8 +622,8 @@ PYBIND_REGISTER(TFRecordNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(UDPOSNode, 2, ([](const py::module *m) { PYBIND_REGISTER(UDPOSNode, 2, ([](const py::module *m) {
(void)py::class_<UDPOSNode, DatasetNode, std::shared_ptr<UDPOSNode>>(*m, "UDPOSNode", (void)py::class_<UDPOSNode, DatasetNode, std::shared_ptr<UDPOSNode>>(*m, "UDPOSNode",
"to create an UDPOSNode") "to create an UDPOSNode")
.def(py::init([](std::string dataset_dir, std::string usage, int64_t num_samples, int32_t shuffle, .def(py::init([](const std::string &dataset_dir, const std::string &usage, int64_t num_samples,
int32_t num_shards, int32_t shard_id) { int32_t shuffle, int32_t num_shards, int32_t shard_id) {
std::shared_ptr<UDPOSNode> udpos = std::make_shared<UDPOSNode>( std::shared_ptr<UDPOSNode> udpos = std::make_shared<UDPOSNode>(
dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr); dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
THROW_IF_ERROR(udpos->ValidateParams()); THROW_IF_ERROR(udpos->ValidateParams());
@ -618,8 +634,8 @@ PYBIND_REGISTER(UDPOSNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(USPSNode, 2, ([](const py::module *m) { PYBIND_REGISTER(USPSNode, 2, ([](const py::module *m) {
(void)py::class_<USPSNode, DatasetNode, std::shared_ptr<USPSNode>>(*m, "USPSNode", (void)py::class_<USPSNode, DatasetNode, std::shared_ptr<USPSNode>>(*m, "USPSNode",
"to create an USPSNode") "to create an USPSNode")
.def(py::init([](std::string dataset_dir, std::string usage, int32_t num_samples, int32_t shuffle, .def(py::init([](const std::string &dataset_dir, const std::string &usage, int32_t num_samples,
int32_t num_shards, int32_t shard_id) { int32_t shuffle, int32_t num_shards, int32_t shard_id) {
auto usps = std::make_shared<USPSNode>(dataset_dir, usage, num_samples, toShuffleMode(shuffle), auto usps = std::make_shared<USPSNode>(dataset_dir, usage, num_samples, toShuffleMode(shuffle),
num_shards, shard_id, nullptr); num_shards, shard_id, nullptr);
THROW_IF_ERROR(usps->ValidateParams()); THROW_IF_ERROR(usps->ValidateParams());
@ -629,7 +645,7 @@ PYBIND_REGISTER(USPSNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(VOCNode, 2, ([](const py::module *m) { PYBIND_REGISTER(VOCNode, 2, ([](const py::module *m) {
(void)py::class_<VOCNode, DatasetNode, std::shared_ptr<VOCNode>>(*m, "VOCNode", "to create a VOCNode") (void)py::class_<VOCNode, DatasetNode, std::shared_ptr<VOCNode>>(*m, "VOCNode", "to create a VOCNode")
.def(py::init([](std::string dataset_dir, std::string task, std::string usage, .def(py::init([](const std::string &dataset_dir, const std::string &task, const std::string &usage,
const py::dict &class_indexing, bool decode, const py::handle &sampler, const py::dict &class_indexing, bool decode, const py::handle &sampler,
bool extra_metadata) { bool extra_metadata) {
std::shared_ptr<VOCNode> voc = std::shared_ptr<VOCNode> voc =
@ -643,7 +659,8 @@ PYBIND_REGISTER(VOCNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(WIDERFaceNode, 2, ([](const py::module *m) { PYBIND_REGISTER(WIDERFaceNode, 2, ([](const py::module *m) {
(void)py::class_<WIDERFaceNode, DatasetNode, std::shared_ptr<WIDERFaceNode>>( (void)py::class_<WIDERFaceNode, DatasetNode, std::shared_ptr<WIDERFaceNode>>(
*m, "WIDERFaceNode", "to create a WIDERFaceNode") *m, "WIDERFaceNode", "to create a WIDERFaceNode")
.def(py::init([](std::string dataset_dir, std::string usage, bool decode, py::handle sampler) { .def(py::init([](const std::string &dataset_dir, const std::string &usage, bool decode,
const py::handle &sampler) {
auto wider_face = auto wider_face =
std::make_shared<WIDERFaceNode>(dataset_dir, usage, decode, toSamplerObj(sampler), nullptr); std::make_shared<WIDERFaceNode>(dataset_dir, usage, decode, toSamplerObj(sampler), nullptr);
THROW_IF_ERROR(wider_face->ValidateParams()); THROW_IF_ERROR(wider_face->ValidateParams());
@ -654,8 +671,8 @@ PYBIND_REGISTER(WIDERFaceNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(WikiTextNode, 2, ([](const py::module *m) { PYBIND_REGISTER(WikiTextNode, 2, ([](const py::module *m) {
(void)py::class_<WikiTextNode, DatasetNode, std::shared_ptr<WikiTextNode>>(*m, "WikiTextNode", (void)py::class_<WikiTextNode, DatasetNode, std::shared_ptr<WikiTextNode>>(*m, "WikiTextNode",
"to create a WikiTextNode") "to create a WikiTextNode")
.def(py::init([](std::string dataset_dir, std::string usage, int32_t num_samples, int32_t shuffle, .def(py::init([](const std::string &dataset_dir, const std::string &usage, int32_t num_samples,
int32_t num_shards, int32_t shard_id) { int32_t shuffle, int32_t num_shards, int32_t shard_id) {
auto wiki_text = std::make_shared<WikiTextNode>( auto wiki_text = std::make_shared<WikiTextNode>(
dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr); dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
THROW_IF_ERROR(wiki_text->ValidateParams()); THROW_IF_ERROR(wiki_text->ValidateParams());
@ -666,8 +683,8 @@ PYBIND_REGISTER(WikiTextNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(YahooAnswersNode, 2, ([](const py::module *m) { PYBIND_REGISTER(YahooAnswersNode, 2, ([](const py::module *m) {
(void)py::class_<YahooAnswersNode, DatasetNode, std::shared_ptr<YahooAnswersNode>>( (void)py::class_<YahooAnswersNode, DatasetNode, std::shared_ptr<YahooAnswersNode>>(
*m, "YahooAnswersNode", "to create a YahooAnswersNode") *m, "YahooAnswersNode", "to create a YahooAnswersNode")
.def(py::init([](std::string dataset_dir, std::string usage, int64_t num_samples, int32_t shuffle, .def(py::init([](const std::string &dataset_dir, const std::string &usage, int64_t num_samples,
int32_t num_shards, int32_t shard_id) { int32_t shuffle, int32_t num_shards, int32_t shard_id) {
auto yahoo_answers = std::make_shared<YahooAnswersNode>( auto yahoo_answers = std::make_shared<YahooAnswersNode>(
dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr); dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
THROW_IF_ERROR(yahoo_answers->ValidateParams()); THROW_IF_ERROR(yahoo_answers->ValidateParams());
@ -678,8 +695,8 @@ PYBIND_REGISTER(YahooAnswersNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(YelpReviewNode, 2, ([](const py::module *m) { PYBIND_REGISTER(YelpReviewNode, 2, ([](const py::module *m) {
(void)py::class_<YelpReviewNode, DatasetNode, std::shared_ptr<YelpReviewNode>>( (void)py::class_<YelpReviewNode, DatasetNode, std::shared_ptr<YelpReviewNode>>(
*m, "YelpReviewNode", "to create a YelpReviewNode") *m, "YelpReviewNode", "to create a YelpReviewNode")
.def(py::init([](std::string dataset_dir, std::string usage, int64_t num_samples, int32_t shuffle, .def(py::init([](const std::string &dataset_dir, const std::string &usage, int64_t num_samples,
int32_t num_shards, int32_t shard_id) { int32_t shuffle, int32_t num_shards, int32_t shard_id) {
std::shared_ptr<YelpReviewNode> yelp_review = std::make_shared<YelpReviewNode>( std::shared_ptr<YelpReviewNode> yelp_review = std::make_shared<YelpReviewNode>(
dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr); dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
THROW_IF_ERROR(yelp_review->ValidateParams()); THROW_IF_ERROR(yelp_review->ValidateParams());
@ -690,7 +707,7 @@ PYBIND_REGISTER(YelpReviewNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(YesNoNode, 2, ([](const py::module *m) { PYBIND_REGISTER(YesNoNode, 2, ([](const py::module *m) {
(void)py::class_<YesNoNode, DatasetNode, std::shared_ptr<YesNoNode>>(*m, "YesNoNode", (void)py::class_<YesNoNode, DatasetNode, std::shared_ptr<YesNoNode>>(*m, "YesNoNode",
"to create a YesNoNode") "to create a YesNoNode")
.def(py::init([](std::string dataset_dir, py::handle sampler) { .def(py::init([](const std::string &dataset_dir, const py::handle &sampler) {
auto yes_no = std::make_shared<YesNoNode>(dataset_dir, toSamplerObj(sampler), nullptr); auto yes_no = std::make_shared<YesNoNode>(dataset_dir, toSamplerObj(sampler), nullptr);
THROW_IF_ERROR(yes_no->ValidateParams()); THROW_IF_ERROR(yes_no->ValidateParams());
return yes_no; return yes_no;

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -14,6 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
#include "pybind11/pybind11.h" #include "pybind11/pybind11.h"
#include "minddata/dataset/api/python/pybind_conversion.h" #include "minddata/dataset/api/python/pybind_conversion.h"
#include "minddata/dataset/api/python/pybind_register.h" #include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/include/dataset/transforms.h" #include "minddata/dataset/include/dataset/transforms.h"
@ -112,7 +113,7 @@ PYBIND_REGISTER(BoundingBoxAugmentOperation, 1, ([](const py::module *m) {
(void)py::class_<vision::BoundingBoxAugmentOperation, TensorOperation, (void)py::class_<vision::BoundingBoxAugmentOperation, TensorOperation,
std::shared_ptr<vision::BoundingBoxAugmentOperation>>(*m, std::shared_ptr<vision::BoundingBoxAugmentOperation>>(*m,
"BoundingBoxAugmentOperation") "BoundingBoxAugmentOperation")
.def(py::init([](const py::object transform, float ratio) { .def(py::init([](const py::object &transform, float ratio) {
auto bounding_box_augment = std::make_shared<vision::BoundingBoxAugmentOperation>( auto bounding_box_augment = std::make_shared<vision::BoundingBoxAugmentOperation>(
std::move(toTensorOperation(transform)), ratio); std::move(toTensorOperation(transform)), ratio);
THROW_IF_ERROR(bounding_box_augment->ValidateParams()); THROW_IF_ERROR(bounding_box_augment->ValidateParams());
@ -207,7 +208,7 @@ PYBIND_REGISTER(
GaussianBlurOperation, 1, ([](const py::module *m) { GaussianBlurOperation, 1, ([](const py::module *m) {
(void)py::class_<vision::GaussianBlurOperation, TensorOperation, std::shared_ptr<vision::GaussianBlurOperation>>( (void)py::class_<vision::GaussianBlurOperation, TensorOperation, std::shared_ptr<vision::GaussianBlurOperation>>(
*m, "GaussianBlurOperation") *m, "GaussianBlurOperation")
.def(py::init([](std::vector<int32_t> kernel_size, std::vector<float> sigma) { .def(py::init([](const std::vector<int32_t> &kernel_size, const std::vector<float> &sigma) {
auto gaussian_blur = std::make_shared<vision::GaussianBlurOperation>(kernel_size, sigma); auto gaussian_blur = std::make_shared<vision::GaussianBlurOperation>(kernel_size, sigma);
THROW_IF_ERROR(gaussian_blur->ValidateParams()); THROW_IF_ERROR(gaussian_blur->ValidateParams());
return gaussian_blur; return gaussian_blur;
@ -504,23 +505,25 @@ PYBIND_REGISTER(RandomResizeWithBBoxOperation, 1, ([](const py::module *m) {
})); }));
})); }));
PYBIND_REGISTER(RandomRotationOperation, 1, ([](const py::module *m) { PYBIND_REGISTER(
(void)py::class_<vision::RandomRotationOperation, TensorOperation, RandomRotationOperation, 1, ([](const py::module *m) {
std::shared_ptr<vision::RandomRotationOperation>>(*m, "RandomRotationOperation") (void)
.def(py::init([](std::vector<float> degrees, InterpolationMode interpolation_mode, bool expand, py::class_<vision::RandomRotationOperation, TensorOperation, std::shared_ptr<vision::RandomRotationOperation>>(
std::vector<float> center, std::vector<uint8_t> fill_value) { *m, "RandomRotationOperation")
auto random_rotation = std::make_shared<vision::RandomRotationOperation>( .def(py::init([](const std::vector<float> &degrees, InterpolationMode interpolation_mode, bool expand,
degrees, interpolation_mode, expand, center, fill_value); const std::vector<float> &center, const std::vector<uint8_t> &fill_value) {
THROW_IF_ERROR(random_rotation->ValidateParams()); auto random_rotation =
return random_rotation; std::make_shared<vision::RandomRotationOperation>(degrees, interpolation_mode, expand, center, fill_value);
})); THROW_IF_ERROR(random_rotation->ValidateParams());
})); return random_rotation;
}));
}));
PYBIND_REGISTER( PYBIND_REGISTER(
RandomSelectSubpolicyOperation, 1, ([](const py::module *m) { RandomSelectSubpolicyOperation, 1, ([](const py::module *m) {
(void)py::class_<vision::RandomSelectSubpolicyOperation, TensorOperation, (void)py::class_<vision::RandomSelectSubpolicyOperation, TensorOperation,
std::shared_ptr<vision::RandomSelectSubpolicyOperation>>(*m, "RandomSelectSubpolicyOperation") std::shared_ptr<vision::RandomSelectSubpolicyOperation>>(*m, "RandomSelectSubpolicyOperation")
.def(py::init([](const py::list py_policy) { .def(py::init([](const py::list &py_policy) {
std::vector<std::vector<std::pair<std::shared_ptr<TensorOperation>, double>>> cpp_policy; std::vector<std::vector<std::pair<std::shared_ptr<TensorOperation>, double>>> cpp_policy;
for (auto &py_sub : py_policy) { for (auto &py_sub : py_policy) {
cpp_policy.push_back({}); cpp_policy.push_back({});
@ -643,8 +646,8 @@ PYBIND_REGISTER(RgbToBgrOperation, 1, ([](const py::module *m) {
PYBIND_REGISTER(RotateOperation, 1, ([](const py::module *m) { PYBIND_REGISTER(RotateOperation, 1, ([](const py::module *m) {
(void)py::class_<vision::RotateOperation, TensorOperation, std::shared_ptr<vision::RotateOperation>>( (void)py::class_<vision::RotateOperation, TensorOperation, std::shared_ptr<vision::RotateOperation>>(
*m, "RotateOperation") *m, "RotateOperation")
.def(py::init([](float degrees, InterpolationMode resample, bool expand, std::vector<float> center, .def(py::init([](float degrees, InterpolationMode resample, bool expand,
std::vector<uint8_t> fill_value) { const std::vector<float> &center, const std::vector<uint8_t> &fill_value) {
auto rotate = auto rotate =
std::make_shared<vision::RotateOperation>(degrees, resample, expand, center, fill_value); std::make_shared<vision::RotateOperation>(degrees, resample, expand, center, fill_value);
THROW_IF_ERROR(rotate->ValidateParams()); THROW_IF_ERROR(rotate->ValidateParams());
@ -694,7 +697,7 @@ PYBIND_REGISTER(
UniformAugOperation, 1, ([](const py::module *m) { UniformAugOperation, 1, ([](const py::module *m) {
(void)py::class_<vision::UniformAugOperation, TensorOperation, std::shared_ptr<vision::UniformAugOperation>>( (void)py::class_<vision::UniformAugOperation, TensorOperation, std::shared_ptr<vision::UniformAugOperation>>(
*m, "UniformAugOperation") *m, "UniformAugOperation")
.def(py::init([](const py::list transforms, int32_t num_ops) { .def(py::init([](const py::list &transforms, int32_t num_ops) {
auto uniform_aug = auto uniform_aug =
std::make_shared<vision::UniformAugOperation>(std::move(toTensorOperations(transforms)), num_ops); std::make_shared<vision::UniformAugOperation>(std::move(toTensorOperations(transforms)), num_ops);
THROW_IF_ERROR(uniform_aug->ValidateParams()); THROW_IF_ERROR(uniform_aug->ValidateParams());

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020-2021 Huawei Technologies Co., Ltd * Copyright 2020-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -15,9 +15,11 @@
*/ */
#include "minddata/dataset/include/dataset/samplers.h" #include "minddata/dataset/include/dataset/samplers.h"
#include <utility>
#include "minddata/dataset/engine/ir/datasetops/source/samplers/distributed_sampler_ir.h" #include "minddata/dataset/engine/ir/datasetops/source/samplers/distributed_sampler_ir.h"
#include "minddata/dataset/engine/ir/datasetops/source/samplers/pk_sampler_ir.h" #include "minddata/dataset/engine/ir/datasetops/source/samplers/pk_sampler_ir.h"
#include "minddata/dataset/engine/ir/datasetops/source/samplers/prebuilt_sampler_ir.h"
#include "minddata/dataset/engine/ir/datasetops/source/samplers/random_sampler_ir.h" #include "minddata/dataset/engine/ir/datasetops/source/samplers/random_sampler_ir.h"
#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h" #include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"
#include "minddata/dataset/engine/ir/datasetops/source/samplers/sequential_sampler_ir.h" #include "minddata/dataset/engine/ir/datasetops/source/samplers/sequential_sampler_ir.h"
@ -27,7 +29,6 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
// DistributedSampler // DistributedSampler
DistributedSampler::DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples, DistributedSampler::DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples,
uint32_t seed, int64_t offset, bool even_dist) uint32_t seed, int64_t offset, bool even_dist)
@ -69,7 +70,7 @@ std::shared_ptr<SamplerObj> SequentialSampler::Parse() const {
} }
// SubsetSampler // SubsetSampler
SubsetSampler::SubsetSampler(std::vector<int64_t> indices, int64_t num_samples) SubsetSampler::SubsetSampler(const std::vector<int64_t> &indices, int64_t num_samples)
: indices_(indices), num_samples_(num_samples) {} : indices_(indices), num_samples_(num_samples) {}
std::shared_ptr<SamplerObj> SubsetSampler::Parse() const { std::shared_ptr<SamplerObj> SubsetSampler::Parse() const {
@ -77,7 +78,7 @@ std::shared_ptr<SamplerObj> SubsetSampler::Parse() const {
} }
// SubsetRandomSampler // SubsetRandomSampler
SubsetRandomSampler::SubsetRandomSampler(std::vector<int64_t> indices, int64_t num_samples) SubsetRandomSampler::SubsetRandomSampler(const std::vector<int64_t> &indices, int64_t num_samples)
: SubsetSampler(indices, num_samples) {} : SubsetSampler(indices, num_samples) {}
std::shared_ptr<SamplerObj> SubsetRandomSampler::Parse() const { std::shared_ptr<SamplerObj> SubsetRandomSampler::Parse() const {
@ -85,12 +86,11 @@ std::shared_ptr<SamplerObj> SubsetRandomSampler::Parse() const {
} }
// WeightedRandomSampler // WeightedRandomSampler
WeightedRandomSampler::WeightedRandomSampler(std::vector<double> weights, int64_t num_samples, bool replacement) WeightedRandomSampler::WeightedRandomSampler(const std::vector<double> &weights, int64_t num_samples, bool replacement)
: weights_(weights), num_samples_(num_samples), replacement_(replacement) {} : weights_(weights), num_samples_(num_samples), replacement_(replacement) {}
std::shared_ptr<SamplerObj> WeightedRandomSampler::Parse() const { std::shared_ptr<SamplerObj> WeightedRandomSampler::Parse() const {
return std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_); return std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_);
} }
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020-2021 Huawei Technologies Co., Ltd * Copyright 2020-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -13,23 +13,21 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "minddata/dataset/include/dataset/text.h"
#include <unistd.h> #include <unistd.h>
#include <fstream> #include <fstream>
#include <regex> #include <regex>
#include "utils/file_utils.h"
#include "minddata/dataset/include/dataset/text.h"
#include "minddata/dataset/core/type_id.h" #include "minddata/dataset/core/type_id.h"
#include "minddata/dataset/text/ir/kernels/text_ir.h" #include "minddata/dataset/text/ir/kernels/text_ir.h"
#include "mindspore/core/ir/dtype/type_id.h" #include "mindspore/core/ir/dtype/type_id.h"
#include "utils/file_utils.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
// Transform operations for text. // Transform operations for text.
namespace text { namespace text {
constexpr size_t size_two = 2; constexpr size_t size_two = 2;
constexpr size_t size_three = 3; constexpr size_t size_three = 3;
constexpr int64_t value_one = 1; constexpr int64_t value_one = 1;
@ -104,7 +102,7 @@ std::shared_ptr<TensorOperation> BertTokenizer::Parse() {
} }
// CaseFold // CaseFold
CaseFold::CaseFold() {} CaseFold::CaseFold() = default;
std::shared_ptr<TensorOperation> CaseFold::Parse() { return std::make_shared<CaseFoldOperation>(); } std::shared_ptr<TensorOperation> CaseFold::Parse() { return std::make_shared<CaseFoldOperation>(); }
#endif #endif
@ -170,6 +168,7 @@ Status JiebaTokenizer::AddDictChar(const std::vector<char> &file_path) {
Status JiebaTokenizer::ParserFile(const std::string &file_path, Status JiebaTokenizer::ParserFile(const std::string &file_path,
std::vector<std::pair<std::string, int64_t>> *const user_dict) { std::vector<std::pair<std::string, int64_t>> *const user_dict) {
RETURN_UNEXPECTED_IF_NULL(user_dict);
auto realpath = FileUtils::GetRealPath(file_path.data()); auto realpath = FileUtils::GetRealPath(file_path.data());
if (!realpath.has_value()) { if (!realpath.has_value()) {
std::string err_msg = "Get real path failed, path: " + file_path; std::string err_msg = "Get real path failed, path: " + file_path;
@ -193,7 +192,7 @@ Status JiebaTokenizer::ParserFile(const std::string &file_path,
if (tokens.size() == size_two) { if (tokens.size() == size_two) {
(void)user_dict->emplace_back(tokens.str(value_one), 0); (void)user_dict->emplace_back(tokens.str(value_one), 0);
} else if (tokens.size() == size_three) { } else if (tokens.size() == size_three) {
(void)user_dict->emplace_back(tokens.str(value_one), strtoll(tokens.str(value_two).c_str(), NULL, 0)); (void)user_dict->emplace_back(tokens.str(value_one), strtoll(tokens.str(value_two).c_str(), nullptr, 0));
} else { } else {
continue; continue;
} }
@ -201,6 +200,7 @@ Status JiebaTokenizer::ParserFile(const std::string &file_path,
continue; continue;
} }
} }
ifs.close();
MS_LOG(INFO) << "JiebaTokenizer::AddDict: The size of user input dictionary is: " << user_dict->size(); MS_LOG(INFO) << "JiebaTokenizer::AddDict: The size of user input dictionary is: " << user_dict->size();
MS_LOG(INFO) << "Valid rows in input dictionary (Maximum of first 10 rows are shown.):"; MS_LOG(INFO) << "Valid rows in input dictionary (Maximum of first 10 rows are shown.):";
for (std::size_t i = 0; i != user_dict->size(); ++i) { for (std::size_t i = 0; i != user_dict->size(); ++i) {
@ -367,7 +367,8 @@ struct ToVectors::Data {
bool lower_case_backup_; bool lower_case_backup_;
}; };
ToVectors::ToVectors(const std::shared_ptr<Vectors> &vectors, const std::vector<float> unk_init, bool lower_case_backup) ToVectors::ToVectors(const std::shared_ptr<Vectors> &vectors, const std::vector<float> &unk_init,
bool lower_case_backup)
: data_(std::make_shared<Data>(vectors, unk_init, lower_case_backup)) {} : data_(std::make_shared<Data>(vectors, unk_init, lower_case_backup)) {}
std::shared_ptr<TensorOperation> ToVectors::Parse() { std::shared_ptr<TensorOperation> ToVectors::Parse() {

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020-2021 Huawei Technologies Co., Ltd * Copyright 2020-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -18,17 +18,14 @@
#include <algorithm> #include <algorithm>
#include "mindspore/ccsrc/minddata/dataset/core/type_id.h"
#include "mindspore/core/ir/dtype/type_id.h"
#include "minddata/dataset/core/type_id.h" #include "minddata/dataset/core/type_id.h"
#include "minddata/dataset/kernels/ir/data/transforms_ir.h" #include "minddata/dataset/kernels/ir/data/transforms_ir.h"
#include "mindspore/core/ir/dtype/type_id.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
// Transform operations for data. // Transform operations for data.
namespace transforms { namespace transforms {
// API CLASS FOR DATA TRANSFORM OPERATIONS // API CLASS FOR DATA TRANSFORM OPERATIONS
// (In alphabetical order) // (In alphabetical order)
@ -46,7 +43,7 @@ Compose::Compose(const std::vector<TensorTransform *> &transforms) : data_(std::
Compose::Compose(const std::vector<std::shared_ptr<TensorTransform>> &transforms) : data_(std::make_shared<Data>()) { Compose::Compose(const std::vector<std::shared_ptr<TensorTransform>> &transforms) : data_(std::make_shared<Data>()) {
(void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_), (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
[](std::shared_ptr<TensorTransform> op) -> std::shared_ptr<TensorOperation> { [](const std::shared_ptr<TensorTransform> &op) -> std::shared_ptr<TensorOperation> {
return op != nullptr ? op->Parse() : nullptr; return op != nullptr ? op->Parse() : nullptr;
}); });
} }
@ -92,7 +89,7 @@ std::shared_ptr<TensorOperation> Concatenate::Parse() {
} }
// Constructor to Duplicate // Constructor to Duplicate
Duplicate::Duplicate() {} Duplicate::Duplicate() = default;
std::shared_ptr<TensorOperation> Duplicate::Parse() { return std::make_shared<DuplicateOperation>(); } std::shared_ptr<TensorOperation> Duplicate::Parse() { return std::make_shared<DuplicateOperation>(); }
@ -202,7 +199,7 @@ RandomApply::RandomApply(const std::vector<TensorTransform *> &transforms, doubl
RandomApply::RandomApply(const std::vector<std::shared_ptr<TensorTransform>> &transforms, double prob) RandomApply::RandomApply(const std::vector<std::shared_ptr<TensorTransform>> &transforms, double prob)
: data_(std::make_shared<Data>()) { : data_(std::make_shared<Data>()) {
(void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_), (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
[](std::shared_ptr<TensorTransform> op) -> std::shared_ptr<TensorOperation> { [](const std::shared_ptr<TensorTransform> &op) -> std::shared_ptr<TensorOperation> {
return op != nullptr ? op->Parse() : nullptr; return op != nullptr ? op->Parse() : nullptr;
}); });
data_->prob_ = prob; data_->prob_ = prob;
@ -234,7 +231,7 @@ RandomChoice::RandomChoice(const std::vector<TensorTransform *> &transforms) : d
RandomChoice::RandomChoice(const std::vector<std::shared_ptr<TensorTransform>> &transforms) RandomChoice::RandomChoice(const std::vector<std::shared_ptr<TensorTransform>> &transforms)
: data_(std::make_shared<Data>()) { : data_(std::make_shared<Data>()) {
(void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_), (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
[](const std::shared_ptr<TensorTransform> op) -> std::shared_ptr<TensorOperation> { [](const std::shared_ptr<TensorTransform> &op) -> std::shared_ptr<TensorOperation> {
return op != nullptr ? op->Parse() : nullptr; return op != nullptr ? op->Parse() : nullptr;
}); });
} }
@ -278,7 +275,7 @@ TypeCast::TypeCast(mindspore::DataType data_type) : data_(std::make_shared<Data>
std::shared_ptr<TensorOperation> TypeCast::Parse() { return std::make_shared<TypeCastOperation>(data_->data_type_); } std::shared_ptr<TensorOperation> TypeCast::Parse() { return std::make_shared<TypeCastOperation>(data_->data_type_); }
// Constructor to Unique // Constructor to Unique
Unique::Unique() {} Unique::Unique() = default;
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
std::shared_ptr<TensorOperation> Unique::Parse() { return std::make_shared<UniqueOperation>(); } std::shared_ptr<TensorOperation> Unique::Parse() { return std::make_shared<UniqueOperation>(); }

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020-2021 Huawei Technologies Co., Ltd * Copyright 2020-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -20,7 +20,6 @@
#include "minddata/dataset/kernels/ir/vision/ascend_vision_ir.h" #include "minddata/dataset/kernels/ir/vision/ascend_vision_ir.h"
#endif #endif
#include "minddata/dataset/include/dataset/transforms.h"
#include "minddata/dataset/kernels/ir/vision/adjust_gamma_ir.h" #include "minddata/dataset/kernels/ir/vision/adjust_gamma_ir.h"
#include "minddata/dataset/kernels/ir/vision/affine_ir.h" #include "minddata/dataset/kernels/ir/vision/affine_ir.h"
#include "minddata/dataset/kernels/ir/vision/auto_augment_ir.h" #include "minddata/dataset/kernels/ir/vision/auto_augment_ir.h"
@ -88,11 +87,8 @@
#endif #endif
#include "minddata/dataset/kernels/ir/validators.h" #include "minddata/dataset/kernels/ir/validators.h"
// Kernel image headers (in alphabetical order)
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
// Transform operations for computer vision. // Transform operations for computer vision.
namespace vision { namespace vision {
// CONSTRUCTORS FOR API CLASSES TO CREATE VISION TENSOR TRANSFORM OPERATIONS // CONSTRUCTORS FOR API CLASSES TO CREATE VISION TENSOR TRANSFORM OPERATIONS
@ -163,7 +159,7 @@ struct AutoContrast::Data {
std::vector<uint32_t> ignore_; std::vector<uint32_t> ignore_;
}; };
AutoContrast::AutoContrast(float cutoff, std::vector<uint32_t> ignore) AutoContrast::AutoContrast(float cutoff, const std::vector<uint32_t> &ignore)
: data_(std::make_shared<Data>(cutoff, ignore)) {} : data_(std::make_shared<Data>(cutoff, ignore)) {}
std::shared_ptr<TensorOperation> AutoContrast::Parse() { std::shared_ptr<TensorOperation> AutoContrast::Parse() {
@ -187,7 +183,7 @@ BoundingBoxAugment::BoundingBoxAugment(const std::shared_ptr<TensorTransform> &t
data_->ratio_ = ratio; data_->ratio_ = ratio;
} }
BoundingBoxAugment::BoundingBoxAugment(const std::reference_wrapper<TensorTransform> transform, float ratio) BoundingBoxAugment::BoundingBoxAugment(const std::reference_wrapper<TensorTransform> &transform, float ratio)
: data_(std::make_shared<Data>()) { : data_(std::make_shared<Data>()) {
data_->transform_ = transform.get().Parse(); data_->transform_ = transform.get().Parse();
data_->ratio_ = ratio; data_->ratio_ = ratio;
@ -204,7 +200,7 @@ struct CenterCrop::Data {
std::vector<int32_t> size_; std::vector<int32_t> size_;
}; };
CenterCrop::CenterCrop(std::vector<int32_t> size) : data_(std::make_shared<Data>(size)) {} CenterCrop::CenterCrop(const std::vector<int32_t> &size) : data_(std::make_shared<Data>(size)) {}
std::shared_ptr<TensorOperation> CenterCrop::Parse() { return std::make_shared<CenterCropOperation>(data_->size_); } std::shared_ptr<TensorOperation> CenterCrop::Parse() { return std::make_shared<CenterCropOperation>(data_->size_); }
@ -246,7 +242,7 @@ struct Crop::Data {
std::vector<int32_t> size_; std::vector<int32_t> size_;
}; };
Crop::Crop(std::vector<int32_t> coordinates, std::vector<int32_t> size) Crop::Crop(const std::vector<int32_t> &coordinates, const std::vector<int32_t> &size)
: data_(std::make_shared<Data>(coordinates, size)) {} : data_(std::make_shared<Data>(coordinates, size)) {}
std::shared_ptr<TensorOperation> Crop::Parse() { std::shared_ptr<TensorOperation> Crop::Parse() {
@ -313,7 +309,8 @@ struct DvppDecodeResizeJpeg::Data {
std::vector<uint32_t> resize_; std::vector<uint32_t> resize_;
}; };
DvppDecodeResizeJpeg::DvppDecodeResizeJpeg(std::vector<uint32_t> resize) : data_(std::make_shared<Data>(resize)) {} DvppDecodeResizeJpeg::DvppDecodeResizeJpeg(const std::vector<uint32_t> &resize)
: data_(std::make_shared<Data>(resize)) {}
std::shared_ptr<TensorOperation> DvppDecodeResizeJpeg::Parse() { std::shared_ptr<TensorOperation> DvppDecodeResizeJpeg::Parse() {
return std::make_shared<DvppDecodeResizeOperation>(data_->resize_); return std::make_shared<DvppDecodeResizeOperation>(data_->resize_);
@ -334,7 +331,8 @@ struct DvppDecodeResizeCropJpeg::Data {
std::vector<uint32_t> resize_; std::vector<uint32_t> resize_;
}; };
DvppDecodeResizeCropJpeg::DvppDecodeResizeCropJpeg(std::vector<uint32_t> crop, std::vector<uint32_t> resize) DvppDecodeResizeCropJpeg::DvppDecodeResizeCropJpeg(const std::vector<uint32_t> &crop,
const std::vector<uint32_t> &resize)
: data_(std::make_shared<Data>(crop, resize)) {} : data_(std::make_shared<Data>(crop, resize)) {}
std::shared_ptr<TensorOperation> DvppDecodeResizeCropJpeg::Parse() { std::shared_ptr<TensorOperation> DvppDecodeResizeCropJpeg::Parse() {
@ -365,7 +363,7 @@ std::shared_ptr<TensorOperation> DvppDecodePng::Parse(const MapTargetDevice &env
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
// Equalize Transform Operation. // Equalize Transform Operation.
Equalize::Equalize() {} Equalize::Equalize() = default;
std::shared_ptr<TensorOperation> Equalize::Parse() { return std::make_shared<EqualizeOperation>(); } std::shared_ptr<TensorOperation> Equalize::Parse() { return std::make_shared<EqualizeOperation>(); }
#endif // not ENABLE_ANDROID #endif // not ENABLE_ANDROID
@ -387,17 +385,17 @@ std::shared_ptr<TensorOperation> GaussianBlur::Parse() {
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
// HorizontalFlip Transform Operation. // HorizontalFlip Transform Operation.
HorizontalFlip::HorizontalFlip() {} HorizontalFlip::HorizontalFlip() = default;
std::shared_ptr<TensorOperation> HorizontalFlip::Parse() { return std::make_shared<HorizontalFlipOperation>(); } std::shared_ptr<TensorOperation> HorizontalFlip::Parse() { return std::make_shared<HorizontalFlipOperation>(); }
// HwcToChw Transform Operation. // HwcToChw Transform Operation.
HWC2CHW::HWC2CHW() {} HWC2CHW::HWC2CHW() = default;
std::shared_ptr<TensorOperation> HWC2CHW::Parse() { return std::make_shared<HwcToChwOperation>(); } std::shared_ptr<TensorOperation> HWC2CHW::Parse() { return std::make_shared<HwcToChwOperation>(); }
// Invert Transform Operation. // Invert Transform Operation.
Invert::Invert() {} Invert::Invert() = default;
std::shared_ptr<TensorOperation> Invert::Parse() { return std::make_shared<InvertOperation>(); } std::shared_ptr<TensorOperation> Invert::Parse() { return std::make_shared<InvertOperation>(); }
@ -419,7 +417,8 @@ struct Normalize::Data {
std::vector<float> std_; std::vector<float> std_;
}; };
Normalize::Normalize(std::vector<float> mean, std::vector<float> std) : data_(std::make_shared<Data>(mean, std)) {} Normalize::Normalize(const std::vector<float> &mean, const std::vector<float> &std)
: data_(std::make_shared<Data>(mean, std)) {}
std::shared_ptr<TensorOperation> Normalize::Parse() { std::shared_ptr<TensorOperation> Normalize::Parse() {
return std::make_shared<NormalizeOperation>(data_->mean_, data_->std_); return std::make_shared<NormalizeOperation>(data_->mean_, data_->std_);
@ -464,7 +463,7 @@ struct Pad::Data {
BorderType padding_mode_; BorderType padding_mode_;
}; };
Pad::Pad(std::vector<int32_t> padding, std::vector<uint8_t> fill_value, BorderType padding_mode) Pad::Pad(const std::vector<int32_t> &padding, const std::vector<uint8_t> &fill_value, BorderType padding_mode)
: data_(std::make_shared<Data>(padding, fill_value, padding_mode)) {} : data_(std::make_shared<Data>(padding, fill_value, padding_mode)) {}
std::shared_ptr<TensorOperation> Pad::Parse() { std::shared_ptr<TensorOperation> Pad::Parse() {
@ -524,7 +523,7 @@ struct RandomAutoContrast::Data {
float probability_; float probability_;
}; };
RandomAutoContrast::RandomAutoContrast(float cutoff, std::vector<uint32_t> ignore, float prob) RandomAutoContrast::RandomAutoContrast(float cutoff, const std::vector<uint32_t> &ignore, float prob)
: data_(std::make_shared<Data>(cutoff, ignore, prob)) {} : data_(std::make_shared<Data>(cutoff, ignore, prob)) {}
std::shared_ptr<TensorOperation> RandomAutoContrast::Parse() { std::shared_ptr<TensorOperation> RandomAutoContrast::Parse() {
@ -555,8 +554,8 @@ struct RandomColorAdjust::Data {
std::vector<float> hue_; std::vector<float> hue_;
}; };
RandomColorAdjust::RandomColorAdjust(std::vector<float> brightness, std::vector<float> contrast, RandomColorAdjust::RandomColorAdjust(const std::vector<float> &brightness, const std::vector<float> &contrast,
std::vector<float> saturation, std::vector<float> hue) const std::vector<float> &saturation, const std::vector<float> &hue)
: data_(std::make_shared<Data>(brightness, contrast, saturation, hue)) {} : data_(std::make_shared<Data>(brightness, contrast, saturation, hue)) {}
std::shared_ptr<TensorOperation> RandomColorAdjust::Parse() { std::shared_ptr<TensorOperation> RandomColorAdjust::Parse() {
@ -580,8 +579,8 @@ struct RandomCrop::Data {
BorderType padding_mode_; BorderType padding_mode_;
}; };
RandomCrop::RandomCrop(std::vector<int32_t> size, std::vector<int32_t> padding, bool pad_if_needed, RandomCrop::RandomCrop(const std::vector<int32_t> &size, const std::vector<int32_t> &padding, bool pad_if_needed,
std::vector<uint8_t> fill_value, BorderType padding_mode) const std::vector<uint8_t> &fill_value, BorderType padding_mode)
: data_(std::make_shared<Data>(size, padding, pad_if_needed, fill_value, padding_mode)) {} : data_(std::make_shared<Data>(size, padding, pad_if_needed, fill_value, padding_mode)) {}
std::shared_ptr<TensorOperation> RandomCrop::Parse() { std::shared_ptr<TensorOperation> RandomCrop::Parse() {
@ -601,8 +600,8 @@ struct RandomCropDecodeResize::Data {
int32_t max_attempts_; int32_t max_attempts_;
}; };
RandomCropDecodeResize::RandomCropDecodeResize(std::vector<int32_t> size, std::vector<float> scale, RandomCropDecodeResize::RandomCropDecodeResize(const std::vector<int32_t> &size, const std::vector<float> &scale,
std::vector<float> ratio, InterpolationMode interpolation, const std::vector<float> &ratio, InterpolationMode interpolation,
int32_t max_attempts) int32_t max_attempts)
: data_(std::make_shared<Data>(size, scale, ratio, interpolation, max_attempts)) {} : data_(std::make_shared<Data>(size, scale, ratio, interpolation, max_attempts)) {}
@ -627,8 +626,9 @@ struct RandomCropWithBBox::Data {
BorderType padding_mode_; BorderType padding_mode_;
}; };
RandomCropWithBBox::RandomCropWithBBox(std::vector<int32_t> size, std::vector<int32_t> padding, bool pad_if_needed, RandomCropWithBBox::RandomCropWithBBox(const std::vector<int32_t> &size, const std::vector<int32_t> &padding,
std::vector<uint8_t> fill_value, BorderType padding_mode) bool pad_if_needed, const std::vector<uint8_t> &fill_value,
BorderType padding_mode)
: data_(std::make_shared<Data>(size, padding, pad_if_needed, fill_value, padding_mode)) {} : data_(std::make_shared<Data>(size, padding, pad_if_needed, fill_value, padding_mode)) {}
std::shared_ptr<TensorOperation> RandomCropWithBBox::Parse() { std::shared_ptr<TensorOperation> RandomCropWithBBox::Parse() {
@ -714,7 +714,7 @@ struct RandomResize::Data {
std::vector<int32_t> size_; std::vector<int32_t> size_;
}; };
RandomResize::RandomResize(std::vector<int32_t> size) : data_(std::make_shared<Data>(size)) {} RandomResize::RandomResize(const std::vector<int32_t> &size) : data_(std::make_shared<Data>(size)) {}
std::shared_ptr<TensorOperation> RandomResize::Parse() { return std::make_shared<RandomResizeOperation>(data_->size_); } std::shared_ptr<TensorOperation> RandomResize::Parse() { return std::make_shared<RandomResizeOperation>(data_->size_); }
@ -724,7 +724,7 @@ struct RandomResizeWithBBox::Data {
std::vector<int32_t> size_; std::vector<int32_t> size_;
}; };
RandomResizeWithBBox::RandomResizeWithBBox(std::vector<int32_t> size) : data_(std::make_shared<Data>(size)) {} RandomResizeWithBBox::RandomResizeWithBBox(const std::vector<int32_t> &size) : data_(std::make_shared<Data>(size)) {}
std::shared_ptr<TensorOperation> RandomResizeWithBBox::Parse() { std::shared_ptr<TensorOperation> RandomResizeWithBBox::Parse() {
return std::make_shared<RandomResizeWithBBoxOperation>(data_->size_); return std::make_shared<RandomResizeWithBBoxOperation>(data_->size_);
@ -742,8 +742,9 @@ struct RandomResizedCrop::Data {
int32_t max_attempts_; int32_t max_attempts_;
}; };
RandomResizedCrop::RandomResizedCrop(std::vector<int32_t> size, std::vector<float> scale, std::vector<float> ratio, RandomResizedCrop::RandomResizedCrop(const std::vector<int32_t> &size, const std::vector<float> &scale,
InterpolationMode interpolation, int32_t max_attempts) const std::vector<float> &ratio, InterpolationMode interpolation,
int32_t max_attempts)
: data_(std::make_shared<Data>(size, scale, ratio, interpolation, max_attempts)) {} : data_(std::make_shared<Data>(size, scale, ratio, interpolation, max_attempts)) {}
std::shared_ptr<TensorOperation> RandomResizedCrop::Parse() { std::shared_ptr<TensorOperation> RandomResizedCrop::Parse() {
@ -763,8 +764,8 @@ struct RandomResizedCropWithBBox::Data {
int32_t max_attempts_; int32_t max_attempts_;
}; };
RandomResizedCropWithBBox::RandomResizedCropWithBBox(std::vector<int32_t> size, std::vector<float> scale, RandomResizedCropWithBBox::RandomResizedCropWithBBox(const std::vector<int32_t> &size, const std::vector<float> &scale,
std::vector<float> ratio, InterpolationMode interpolation, const std::vector<float> &ratio, InterpolationMode interpolation,
int32_t max_attempts) int32_t max_attempts)
: data_(std::make_shared<Data>(size, scale, ratio, interpolation, max_attempts)) {} : data_(std::make_shared<Data>(size, scale, ratio, interpolation, max_attempts)) {}
@ -785,8 +786,8 @@ struct RandomRotation::Data {
std::vector<uint8_t> fill_value_; std::vector<uint8_t> fill_value_;
}; };
RandomRotation::RandomRotation(std::vector<float> degrees, InterpolationMode resample, bool expand, RandomRotation::RandomRotation(const std::vector<float> &degrees, InterpolationMode resample, bool expand,
std::vector<float> center, std::vector<uint8_t> fill_value) const std::vector<float> &center, const std::vector<uint8_t> &fill_value)
: data_(std::make_shared<Data>(degrees, resample, expand, center, fill_value)) {} : data_(std::make_shared<Data>(degrees, resample, expand, center, fill_value)) {}
std::shared_ptr<TensorOperation> RandomRotation::Parse() { std::shared_ptr<TensorOperation> RandomRotation::Parse() {
@ -857,7 +858,7 @@ struct RandomSharpness::Data {
std::vector<float> degrees_; std::vector<float> degrees_;
}; };
RandomSharpness::RandomSharpness(std::vector<float> degrees) : data_(std::make_shared<Data>(degrees)) {} RandomSharpness::RandomSharpness(const std::vector<float> &degrees) : data_(std::make_shared<Data>(degrees)) {}
std::shared_ptr<TensorOperation> RandomSharpness::Parse() { std::shared_ptr<TensorOperation> RandomSharpness::Parse() {
return std::make_shared<RandomSharpnessOperation>(data_->degrees_); return std::make_shared<RandomSharpnessOperation>(data_->degrees_);
@ -869,7 +870,7 @@ struct RandomSolarize::Data {
std::vector<uint8_t> threshold_; std::vector<uint8_t> threshold_;
}; };
RandomSolarize::RandomSolarize(std::vector<uint8_t> threshold) : data_(std::make_shared<Data>(threshold)) {} RandomSolarize::RandomSolarize(const std::vector<uint8_t> &threshold) : data_(std::make_shared<Data>(threshold)) {}
std::shared_ptr<TensorOperation> RandomSolarize::Parse() { std::shared_ptr<TensorOperation> RandomSolarize::Parse() {
return std::make_shared<RandomSolarizeOperation>(data_->threshold_); return std::make_shared<RandomSolarizeOperation>(data_->threshold_);
@ -921,7 +922,7 @@ struct Resize::Data {
InterpolationMode interpolation_; InterpolationMode interpolation_;
}; };
Resize::Resize(std::vector<int32_t> size, InterpolationMode interpolation) Resize::Resize(const std::vector<int32_t> &size, InterpolationMode interpolation)
: data_(std::make_shared<Data>(size, interpolation)) {} : data_(std::make_shared<Data>(size, interpolation)) {}
std::shared_ptr<TensorOperation> Resize::Parse() { std::shared_ptr<TensorOperation> Resize::Parse() {
@ -960,6 +961,40 @@ std::shared_ptr<TensorOperation> ResizePreserveAR::Parse() {
return std::make_shared<ResizePreserveAROperation>(data_->height_, data_->width_, data_->img_orientation_); return std::make_shared<ResizePreserveAROperation>(data_->height_, data_->width_, data_->img_orientation_);
} }
#ifndef ENABLE_ANDROID
// ResizeWithBBox Transform Operation.
struct ResizeWithBBox::Data {
Data(const std::vector<int32_t> &size, InterpolationMode interpolation)
: size_(size), interpolation_(interpolation) {}
std::vector<int32_t> size_;
InterpolationMode interpolation_;
};
ResizeWithBBox::ResizeWithBBox(const std::vector<int32_t> &size, InterpolationMode interpolation)
: data_(std::make_shared<Data>(size, interpolation)) {}
std::shared_ptr<TensorOperation> ResizeWithBBox::Parse() {
return std::make_shared<ResizeWithBBoxOperation>(data_->size_, data_->interpolation_);
}
#endif // not ENABLE_ANDROID
// RGB2BGR Transform Operation.
std::shared_ptr<TensorOperation> RGB2BGR::Parse() { return std::make_shared<RgbToBgrOperation>(); }
// RGB2GRAY Transform Operation.
std::shared_ptr<TensorOperation> RGB2GRAY::Parse() { return std::make_shared<RgbToGrayOperation>(); }
#ifndef ENABLE_ANDROID
// RgbaToBgr Transform Operation.
RGBA2BGR::RGBA2BGR() = default;
std::shared_ptr<TensorOperation> RGBA2BGR::Parse() { return std::make_shared<RgbaToBgrOperation>(); }
// RgbaToRgb Transform Operation.
RGBA2RGB::RGBA2RGB() = default;
std::shared_ptr<TensorOperation> RGBA2RGB::Parse() { return std::make_shared<RgbaToRgbOperation>(); }
// Rotate Transform Operation. // Rotate Transform Operation.
struct Rotate::Data { struct Rotate::Data {
Data(const float &degrees, InterpolationMode resample, bool expand, const std::vector<float> &center, Data(const float &degrees, InterpolationMode resample, bool expand, const std::vector<float> &center,
@ -977,8 +1012,8 @@ struct Rotate::Data {
Rotate::Rotate(FixRotationAngle angle_id) : data_(std::make_shared<Data>(angle_id)) {} Rotate::Rotate(FixRotationAngle angle_id) : data_(std::make_shared<Data>(angle_id)) {}
Rotate::Rotate(float degrees, InterpolationMode resample, bool expand, std::vector<float> center, Rotate::Rotate(float degrees, InterpolationMode resample, bool expand, const std::vector<float> &center,
std::vector<uint8_t> fill_value) const std::vector<uint8_t> &fill_value)
: data_(std::make_shared<Data>(degrees, resample, expand, center, fill_value)) {} : data_(std::make_shared<Data>(degrees, resample, expand, center, fill_value)) {}
std::shared_ptr<TensorOperation> Rotate::Parse() { std::shared_ptr<TensorOperation> Rotate::Parse() {
@ -997,40 +1032,6 @@ std::shared_ptr<TensorOperation> Rotate::Parse() {
return nullptr; return nullptr;
} }
#ifndef ENABLE_ANDROID
// ResizeWithBBox Transform Operation.
struct ResizeWithBBox::Data {
Data(const std::vector<int32_t> &size, InterpolationMode interpolation)
: size_(size), interpolation_(interpolation) {}
std::vector<int32_t> size_;
InterpolationMode interpolation_;
};
ResizeWithBBox::ResizeWithBBox(std::vector<int32_t> size, InterpolationMode interpolation)
: data_(std::make_shared<Data>(size, interpolation)) {}
std::shared_ptr<TensorOperation> ResizeWithBBox::Parse() {
return std::make_shared<ResizeWithBBoxOperation>(data_->size_, data_->interpolation_);
}
#endif // not ENABLE_ANDROID
// RGB2BGR Transform Operation.
std::shared_ptr<TensorOperation> RGB2BGR::Parse() { return std::make_shared<RgbToBgrOperation>(); }
// RGB2GRAY Transform Operation.
std::shared_ptr<TensorOperation> RGB2GRAY::Parse() { return std::make_shared<RgbToGrayOperation>(); }
#ifndef ENABLE_ANDROID
// RgbaToBgr Transform Operation.
RGBA2BGR::RGBA2BGR() {}
std::shared_ptr<TensorOperation> RGBA2BGR::Parse() { return std::make_shared<RgbaToBgrOperation>(); }
// RgbaToRgb Transform Operation.
RGBA2RGB::RGBA2RGB() {}
std::shared_ptr<TensorOperation> RGBA2RGB::Parse() { return std::make_shared<RgbaToRgbOperation>(); }
// SlicePatches Transform Operation. // SlicePatches Transform Operation.
struct SlicePatches::Data { struct SlicePatches::Data {
Data(int32_t num_height, int32_t num_width, SliceMode slice_mode, uint8_t fill_value) Data(int32_t num_height, int32_t num_width, SliceMode slice_mode, uint8_t fill_value)
@ -1060,9 +1061,10 @@ struct SoftDvppDecodeRandomCropResizeJpeg::Data {
int32_t max_attempts_; int32_t max_attempts_;
}; };
SoftDvppDecodeRandomCropResizeJpeg::SoftDvppDecodeRandomCropResizeJpeg(std::vector<int32_t> size, SoftDvppDecodeRandomCropResizeJpeg::SoftDvppDecodeRandomCropResizeJpeg(const std::vector<int32_t> &size,
std::vector<float> scale, const std::vector<float> &scale,
std::vector<float> ratio, int32_t max_attempts) const std::vector<float> &ratio,
int32_t max_attempts)
: data_(std::make_shared<Data>(size, scale, ratio, max_attempts)) {} : data_(std::make_shared<Data>(size, scale, ratio, max_attempts)) {}
std::shared_ptr<TensorOperation> SoftDvppDecodeRandomCropResizeJpeg::Parse() { std::shared_ptr<TensorOperation> SoftDvppDecodeRandomCropResizeJpeg::Parse() {
@ -1076,14 +1078,15 @@ struct SoftDvppDecodeResizeJpeg::Data {
std::vector<int32_t> size_; std::vector<int32_t> size_;
}; };
SoftDvppDecodeResizeJpeg::SoftDvppDecodeResizeJpeg(std::vector<int32_t> size) : data_(std::make_shared<Data>(size)) {} SoftDvppDecodeResizeJpeg::SoftDvppDecodeResizeJpeg(const std::vector<int32_t> &size)
: data_(std::make_shared<Data>(size)) {}
std::shared_ptr<TensorOperation> SoftDvppDecodeResizeJpeg::Parse() { std::shared_ptr<TensorOperation> SoftDvppDecodeResizeJpeg::Parse() {
return std::make_shared<SoftDvppDecodeResizeJpegOperation>(data_->size_); return std::make_shared<SoftDvppDecodeResizeJpegOperation>(data_->size_);
} }
// SwapRedBlue Transform Operation. // SwapRedBlue Transform Operation.
SwapRedBlue::SwapRedBlue() {} SwapRedBlue::SwapRedBlue() = default;
std::shared_ptr<TensorOperation> SwapRedBlue::Parse() { return std::make_shared<SwapRedBlueOperation>(); } std::shared_ptr<TensorOperation> SwapRedBlue::Parse() { return std::make_shared<SwapRedBlueOperation>(); }
@ -1104,7 +1107,7 @@ UniformAugment::UniformAugment(const std::vector<TensorTransform *> &transforms,
UniformAugment::UniformAugment(const std::vector<std::shared_ptr<TensorTransform>> &transforms, int32_t num_ops) UniformAugment::UniformAugment(const std::vector<std::shared_ptr<TensorTransform>> &transforms, int32_t num_ops)
: data_(std::make_shared<Data>()) { : data_(std::make_shared<Data>()) {
(void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_), (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
[](const std::shared_ptr<TensorTransform> op) -> std::shared_ptr<TensorOperation> { [](const std::shared_ptr<TensorTransform> &op) -> std::shared_ptr<TensorOperation> {
return op ? op->Parse() : nullptr; return op ? op->Parse() : nullptr;
}); });
data_->num_ops_ = num_ops; data_->num_ops_ = num_ops;
@ -1122,11 +1125,10 @@ std::shared_ptr<TensorOperation> UniformAugment::Parse() {
} }
// VerticalFlip Transform Operation. // VerticalFlip Transform Operation.
VerticalFlip::VerticalFlip() {} VerticalFlip::VerticalFlip() = default;
std::shared_ptr<TensorOperation> VerticalFlip::Parse() { return std::make_shared<VerticalFlipOperation>(); } std::shared_ptr<TensorOperation> VerticalFlip::Parse() { return std::make_shared<VerticalFlipOperation>(); }
#endif // not ENABLE_ANDROID #endif // not ENABLE_ANDROID
} // namespace vision } // namespace vision
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -34,7 +34,7 @@ constexpr char kAllpassBiquadOperation[] = "AllpassBiquad";
class AllpassBiquadOperation : public TensorOperation { class AllpassBiquadOperation : public TensorOperation {
public: public:
explicit AllpassBiquadOperation(int32_t sample_rate, float central_freq, float Q); AllpassBiquadOperation(int32_t sample_rate, float central_freq, float Q);
~AllpassBiquadOperation() = default; ~AllpassBiquadOperation() = default;

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -28,7 +28,6 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace audio { namespace audio {
constexpr char kAmplitudeToDBOperation[] = "AmplitudeToDB"; constexpr char kAmplitudeToDBOperation[] = "AmplitudeToDB";
class AmplitudeToDBOperation : public TensorOperation { class AmplitudeToDBOperation : public TensorOperation {
@ -51,7 +50,6 @@ class AmplitudeToDBOperation : public TensorOperation {
float amin_; float amin_;
float top_db_; float top_db_;
}; };
} // namespace audio } // namespace audio
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -22,7 +22,7 @@ namespace mindspore {
namespace dataset { namespace dataset {
namespace audio { namespace audio {
// AngleOperation // AngleOperation
AngleOperation::AngleOperation() {} AngleOperation::AngleOperation() = default;
Status AngleOperation::ValidateParams() { return Status::OK(); } Status AngleOperation::ValidateParams() { return Status::OK(); }

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -30,7 +30,6 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace audio { namespace audio {
constexpr char kAngleOperation[] = "Angle"; constexpr char kAngleOperation[] = "Angle";
class AngleOperation : public TensorOperation { class AngleOperation : public TensorOperation {

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -28,7 +28,7 @@ BandBiquadOperation::BandBiquadOperation(int32_t sample_rate, float central_freq
Status BandBiquadOperation::ValidateParams() { Status BandBiquadOperation::ValidateParams() {
RETURN_IF_NOT_OK(ValidateScalar("BandBiquad", "Q", Q_, {0, 1.0}, true, false)); RETURN_IF_NOT_OK(ValidateScalar("BandBiquad", "Q", Q_, {0, 1.0}, true, false));
RETURN_IF_NOT_OK(ValidateScalarNotZero("BandBIquad", "sample_rate", sample_rate_)); RETURN_IF_NOT_OK(ValidateScalarNotZero("BandBiquad", "sample_rate", sample_rate_));
return Status::OK(); return Status::OK();
} }

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -30,12 +30,11 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace audio { namespace audio {
constexpr char kBandBiquadOperation[] = "BandBiquad"; constexpr char kBandBiquadOperation[] = "BandBiquad";
class BandBiquadOperation : public TensorOperation { class BandBiquadOperation : public TensorOperation {
public: public:
explicit BandBiquadOperation(int32_t sample_rate, float central_freq, float Q, bool noise); BandBiquadOperation(int32_t sample_rate, float central_freq, float Q, bool noise);
~BandBiquadOperation() = default; ~BandBiquadOperation() = default;
@ -53,7 +52,6 @@ class BandBiquadOperation : public TensorOperation {
float Q_; float Q_;
bool noise_; bool noise_;
}; };
} // namespace audio } // namespace audio
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -30,12 +30,11 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace audio { namespace audio {
constexpr char kBandpassBiquadOperation[] = "BandpassBiquad"; constexpr char kBandpassBiquadOperation[] = "BandpassBiquad";
class BandpassBiquadOperation : public TensorOperation { class BandpassBiquadOperation : public TensorOperation {
public: public:
explicit BandpassBiquadOperation(int32_t sample_rate, float central_freq, float Q, bool const_skirt_gain); BandpassBiquadOperation(int32_t sample_rate, float central_freq, float Q, bool const_skirt_gain);
~BandpassBiquadOperation() = default; ~BandpassBiquadOperation() = default;

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -30,12 +30,11 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace audio { namespace audio {
constexpr char kBandrejectBiquadOperation[] = "BandrejectBiquad"; constexpr char kBandrejectBiquadOperation[] = "BandrejectBiquad";
class BandrejectBiquadOperation : public TensorOperation { class BandrejectBiquadOperation : public TensorOperation {
public: public:
explicit BandrejectBiquadOperation(int32_t sample_rate, float central_freq, float Q); BandrejectBiquadOperation(int32_t sample_rate, float central_freq, float Q);
~BandrejectBiquadOperation() = default; ~BandrejectBiquadOperation() = default;

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -30,12 +30,11 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace audio { namespace audio {
constexpr char kBassBiquadOperation[] = "BassBiquad"; constexpr char kBassBiquadOperation[] = "BassBiquad";
class BassBiquadOperation : public TensorOperation { class BassBiquadOperation : public TensorOperation {
public: public:
explicit BassBiquadOperation(int32_t sample_rate, float gain, float central_freq, float Q); BassBiquadOperation(int32_t sample_rate, float gain, float central_freq, float Q);
~BassBiquadOperation() = default; ~BassBiquadOperation() = default;
@ -53,7 +52,6 @@ class BassBiquadOperation : public TensorOperation {
float central_freq_; float central_freq_;
float Q_; float Q_;
}; };
} // namespace audio } // namespace audio
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -29,9 +29,9 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace audio { namespace audio {
constexpr char kBiquadOperation[] = "Biquad"; constexpr char kBiquadOperation[] = "Biquad";
class BiquadOperation : public TensorOperation { class BiquadOperation : public TensorOperation {
public: public:
BiquadOperation(float b0, float b1, float b2, float a0, float a1, float a2); BiquadOperation(float b0, float b1, float b2, float a0, float a1, float a2);

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_COMPLEX_NORM_IR_H_ #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_COMPLEX_NORM_IR_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_COMPLEX_NORM_IR_H_ #define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_COMPLEX_NORM_IR_H_
@ -26,7 +27,6 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace audio { namespace audio {
constexpr char kComplexNormOperation[] = "ComplexNorm"; constexpr char kComplexNormOperation[] = "ComplexNorm";
class ComplexNormOperation : public TensorOperation { class ComplexNormOperation : public TensorOperation {
@ -46,7 +46,6 @@ class ComplexNormOperation : public TensorOperation {
private: private:
float power_; float power_;
}; // class ComplexNormOperation }; // class ComplexNormOperation
} // namespace audio } // namespace audio
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -28,7 +28,6 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace audio { namespace audio {
constexpr char kContrastOperation[] = "Contrast"; constexpr char kContrastOperation[] = "Contrast";
class ContrastOperation : public TensorOperation { class ContrastOperation : public TensorOperation {

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -28,7 +28,6 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace audio { namespace audio {
constexpr char kDCShiftOperation[] = "DCShift"; constexpr char kDCShiftOperation[] = "DCShift";
class DCShiftOperation : public TensorOperation { class DCShiftOperation : public TensorOperation {

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -28,9 +28,7 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace audio { namespace audio {
constexpr char kDeemphBiquadOperation[] = "DeemphBiquad"; constexpr char kDeemphBiquadOperation[] = "DeemphBiquad";
class DeemphBiquadOperation : public TensorOperation { class DeemphBiquadOperation : public TensorOperation {
@ -51,7 +49,6 @@ class DeemphBiquadOperation : public TensorOperation {
int32_t sample_rate_; int32_t sample_rate_;
}; };
} // namespace audio } // namespace audio
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_DEEMPH_BIQUAD_IR_H_ #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_DEEMPH_BIQUAD_IR_H_

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -17,7 +17,6 @@
#include "minddata/dataset/audio/ir/kernels/equalizer_biquad_ir.h" #include "minddata/dataset/audio/ir/kernels/equalizer_biquad_ir.h"
#include "minddata/dataset/audio/ir/validators.h" #include "minddata/dataset/audio/ir/validators.h"
#include "minddata/dataset/audio/kernels/audio_utils.h"
#include "minddata/dataset/audio/kernels/equalizer_biquad_op.h" #include "minddata/dataset/audio/kernels/equalizer_biquad_op.h"
namespace mindspore { namespace mindspore {

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -35,7 +35,7 @@ std::shared_ptr<TensorOp> FadeOperation::Build() {
return tensor_op; return tensor_op;
} }
Status FadeOperation::to_json(nlohmann::json *const out_json) { Status FadeOperation::to_json(nlohmann::json *out_json) {
nlohmann::json args; nlohmann::json args;
args["fade_in_len"] = fade_in_len_; args["fade_in_len"] = fade_in_len_;
args["fade_out_len"] = fade_out_len_; args["fade_out_len"] = fade_out_len_;

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -27,12 +27,11 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace audio { namespace audio {
constexpr char kFadeOperation[] = "Fade"; constexpr char kFadeOperation[] = "Fade";
class FadeOperation : public TensorOperation { class FadeOperation : public TensorOperation {
public: public:
explicit FadeOperation(int32_t fade_in_len, int32_t fade_out_len, FadeShape fade_shape); FadeOperation(int32_t fade_in_len, int32_t fade_out_len, FadeShape fade_shape);
~FadeOperation() = default; ~FadeOperation() = default;
@ -45,7 +44,7 @@ class FadeOperation : public TensorOperation {
/// \brief Get the arguments of node /// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes /// \param[out] out_json JSON string of all attributes
/// \return Status of the function /// \return Status of the function
Status to_json(nlohmann::json *const out_json) override; Status to_json(nlohmann::json *out_json) override;
private: private:
int32_t fade_in_len_; int32_t fade_in_len_;

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -15,8 +15,9 @@
*/ */
#include "minddata/dataset/audio/ir/kernels/frequency_masking_ir.h" #include "minddata/dataset/audio/ir/kernels/frequency_masking_ir.h"
#include "minddata/dataset/audio/kernels/frequency_masking_op.h"
#include "minddata/dataset/audio/ir/validators.h" #include "minddata/dataset/audio/ir/validators.h"
#include "minddata/dataset/audio/kernels/frequency_masking_op.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -26,7 +26,6 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace audio { namespace audio {
constexpr char kFrequencyMaskingOperation[] = "FrequencyMasking"; constexpr char kFrequencyMaskingOperation[] = "FrequencyMasking";
class FrequencyMaskingOperation : public TensorOperation { class FrequencyMaskingOperation : public TensorOperation {
@ -49,7 +48,6 @@ class FrequencyMaskingOperation : public TensorOperation {
bool iid_masks_; bool iid_masks_;
float mask_value_; float mask_value_;
}; // class FrequencyMaskingOperation }; // class FrequencyMaskingOperation
} // namespace audio } // namespace audio
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -29,9 +29,7 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace audio { namespace audio {
constexpr char kHighpassBiquadOperation[] = "HighpassBiquad"; constexpr char kHighpassBiquadOperation[] = "HighpassBiquad";
class HighpassBiquadOperation : public TensorOperation { class HighpassBiquadOperation : public TensorOperation {

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -28,7 +28,6 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace audio { namespace audio {
// Char arrays storing name of corresponding classes (in alphabetical order) // Char arrays storing name of corresponding classes (in alphabetical order)
constexpr char kLFilterOperation[] = "LFilter"; constexpr char kLFilterOperation[] = "LFilter";
@ -53,7 +52,6 @@ class LFilterOperation : public TensorOperation {
bool clamp_; bool clamp_;
}; };
} // namespace audio } // namespace audio
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_LFILTER_IR_H_ #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_LFILTER_IR_H_

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -52,7 +52,6 @@ class LowpassBiquadOperation : public TensorOperation {
float cutoff_freq_; float cutoff_freq_;
float Q_; float Q_;
}; // class LowpassBiquad }; // class LowpassBiquad
} // namespace audio } // namespace audio
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -27,9 +27,7 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace audio { namespace audio {
constexpr char kMagphaseOperation[] = "Magphase"; constexpr char kMagphaseOperation[] = "Magphase";
class MagphaseOperation : public TensorOperation { class MagphaseOperation : public TensorOperation {
@ -51,7 +49,6 @@ class MagphaseOperation : public TensorOperation {
private: private:
float power_; float power_;
}; };
} // namespace audio } // namespace audio
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "minddata/dataset/audio/ir/kernels/mu_law_decoding_ir.h" #include "minddata/dataset/audio/ir/kernels/mu_law_decoding_ir.h"
#include "minddata/dataset/audio/ir/validators.h" #include "minddata/dataset/audio/ir/validators.h"

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -26,7 +26,6 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace audio { namespace audio {
constexpr char kMuLawDecodingOperation[] = "MuLawDecoding"; constexpr char kMuLawDecodingOperation[] = "MuLawDecoding";
class MuLawDecodingOperation : public TensorOperation { class MuLawDecodingOperation : public TensorOperation {
@ -46,9 +45,7 @@ class MuLawDecodingOperation : public TensorOperation {
private: private:
int32_t quantization_channels_; int32_t quantization_channels_;
}; // class MuLawDecodingOperation }; // class MuLawDecodingOperation
} // namespace audio } // namespace audio
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MU_LAW_DECODING_IR_H_ #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MU_LAW_DECODING_IR_H_

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -26,7 +26,6 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace audio { namespace audio {
constexpr char kMuLawEncodingOperation[] = "MuLawEncoding"; constexpr char kMuLawEncodingOperation[] = "MuLawEncoding";
class MuLawEncodingOperation : public TensorOperation { class MuLawEncodingOperation : public TensorOperation {
@ -49,5 +48,4 @@ class MuLawEncodingOperation : public TensorOperation {
} // namespace audio } // namespace audio
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MU_LAW_ENCODING_IR_H_ #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MU_LAW_ENCODING_IR_H_

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -30,7 +30,6 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace audio { namespace audio {
constexpr char kRiaaBiquadOperation[] = "RiaaBiquad"; constexpr char kRiaaBiquadOperation[] = "RiaaBiquad";
class RiaaBiquadOperation : public TensorOperation { class RiaaBiquadOperation : public TensorOperation {

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -26,7 +26,6 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace audio { namespace audio {
constexpr char kTimeMaskingOperation[] = "TimeMasking"; constexpr char kTimeMaskingOperation[] = "TimeMasking";
class TimeMaskingOperation : public TensorOperation { class TimeMaskingOperation : public TensorOperation {
@ -49,7 +48,6 @@ class TimeMaskingOperation : public TensorOperation {
int32_t mask_start_; int32_t mask_start_;
float mask_value_; float mask_value_;
}; // class TimeMaskingOperation }; // class TimeMaskingOperation
} // namespace audio } // namespace audio
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -26,7 +26,6 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace audio { namespace audio {
constexpr char kTimeStretchOperation[] = "TimeStretch"; constexpr char kTimeStretchOperation[] = "TimeStretch";
class TimeStretchOperation : public TensorOperation { class TimeStretchOperation : public TensorOperation {

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -27,7 +27,6 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace audio { namespace audio {
constexpr char kVolOperation[] = "Vol"; constexpr char kVolOperation[] = "Vol";
class VolOperation : public TensorOperation { class VolOperation : public TensorOperation {
@ -48,7 +47,6 @@ class VolOperation : public TensorOperation {
float gain_; float gain_;
GainType gain_type_; GainType gain_type_;
}; };
} // namespace audio } // namespace audio
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -56,7 +56,7 @@ class EnWik9Op : public TextFileOp {
/// \brief DatasetName name getter. /// \brief DatasetName name getter.
/// \param[in] upper A bool to control if you need upper DatasetName. /// \param[in] upper A bool to control if you need upper DatasetName.
/// \return DatasetName of the current Op. /// \return DatasetName of the current Op.
virtual std::string DatasetName(bool upper = false) const { return upper ? "EnWik9" : "enwik9"; } std::string DatasetName(bool upper = false) const override { return upper ? "EnWik9" : "enwik9"; }
/// \brief Reads a text file and loads the data into multiple TensorRows. /// \brief Reads a text file and loads the data into multiple TensorRows.
/// \param[in] file The file to read. /// \param[in] file The file to read.
@ -70,7 +70,7 @@ class EnWik9Op : public TextFileOp {
/// \brief Count number of rows in each file. /// \brief Count number of rows in each file.
/// \param[in] file Txt file name. /// \param[in] file Txt file name.
/// \return int64_t The total number of rows in file. /// \return int64_t The total number of rows in file.
int64_t CountTotalRows(const std::string &file); int64_t CountTotalRows(const std::string &file) override;
}; };
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020-2021 Huawei Technologies Co., Ltd * Copyright 2020-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020-2021 Huawei Technologies Co., Ltd * Copyright 2020-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -29,8 +29,8 @@ namespace mindspore {
namespace dataset { namespace dataset {
// Constructor for EnWik9Node // Constructor for EnWik9Node
EnWik9Node::EnWik9Node(const std::string &dataset_dir, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards, EnWik9Node::EnWik9Node(const std::string &dataset_dir, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards,
int32_t shard_id, std::shared_ptr<DatasetCache> cache) int32_t shard_id, const std::shared_ptr<DatasetCache> &cache)
: NonMappableSourceNode(std::move(cache)), : NonMappableSourceNode(cache),
num_samples_(num_samples), num_samples_(num_samples),
shuffle_(shuffle), shuffle_(shuffle),
num_shards_(num_shards), num_shards_(num_shards),

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -37,7 +37,7 @@ class EnWik9Node : public NonMappableSourceNode {
/// \param[in] shard_id The id of shard. /// \param[in] shard_id The id of shard.
/// \param[in] cache Tensor cache to use. /// \param[in] cache Tensor cache to use.
EnWik9Node(const std::string &dataset_dir, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards, EnWik9Node(const std::string &dataset_dir, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards,
int32_t shard_id, std::shared_ptr<DatasetCache> cache); int32_t shard_id, const std::shared_ptr<DatasetCache> &cache);
/// \brief Destructor. /// \brief Destructor.
~EnWik9Node() = default; ~EnWik9Node() = default;

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -31,12 +31,10 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
class TensorOperation; class TensorOperation;
// Transform operations for performing computer audio. // Transform operations for performing computer audio.
namespace audio { namespace audio {
/// \brief Compute the angle of complex tensor input. /// \brief Compute the angle of complex tensor input.
class MS_API Angle final : public TensorTransform { class MS_API Angle final : public TensorTransform {
public: public:
@ -51,29 +49,6 @@ class MS_API Angle final : public TensorTransform {
std::shared_ptr<TensorOperation> Parse() override; std::shared_ptr<TensorOperation> Parse() override;
}; };
/// \brief Design two-pole band filter.
class MS_API BandBiquad final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero.
/// \param[in] central_freq Central frequency (in Hz).
/// \param[in] Q Quality factor, https://en.wikipedia.org/wiki/Q_factor, range: (0, 1] (Default: 0.707).
/// \param[in] noise Choose alternate mode for un-pitched audio or mode oriented to pitched audio (Default: False).
explicit BandBiquad(int32_t sample_rate, float central_freq, float Q = 0.707, bool noise = false);
/// \brief Destructor.
~BandBiquad() = default;
protected:
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
private:
struct Data;
std::shared_ptr<Data> data_;
};
/// \brief Design two-pole allpass filter. Similar to SoX implementation. /// \brief Design two-pole allpass filter. Similar to SoX implementation.
class MS_API AllpassBiquad final : public TensorTransform { class MS_API AllpassBiquad final : public TensorTransform {
public: public:
@ -121,6 +96,29 @@ class MS_API AmplitudeToDB final : public TensorTransform {
std::shared_ptr<Data> data_; std::shared_ptr<Data> data_;
}; };
/// \brief Design two-pole band filter.
class MS_API BandBiquad final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero.
/// \param[in] central_freq Central frequency (in Hz).
/// \param[in] Q Quality factor, https://en.wikipedia.org/wiki/Q_factor, range: (0, 1] (Default: 0.707).
/// \param[in] noise Choose alternate mode for un-pitched audio or mode oriented to pitched audio (Default: False).
explicit BandBiquad(int32_t sample_rate, float central_freq, float Q = 0.707, bool noise = false);
/// \brief Destructor.
~BandBiquad() = default;
protected:
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
private:
struct Data;
std::shared_ptr<Data> data_;
};
/// \brief Design two-pole band-pass filter. /// \brief Design two-pole band-pass filter.
class MS_API BandpassBiquad final : public TensorTransform { class MS_API BandpassBiquad final : public TensorTransform {
public: public:
@ -561,7 +559,7 @@ class MS_API LFilter final : public TensorTransform {
/// Lower delays coefficients are first, e.g. [b0, b1, b2, ...]. /// Lower delays coefficients are first, e.g. [b0, b1, b2, ...].
/// Must be same size as a_coeffs (pad with 0's as necessary). /// Must be same size as a_coeffs (pad with 0's as necessary).
/// \param[in] clamp If True, clamp the output signal to be in the range [-1, 1] (Default: True). /// \param[in] clamp If True, clamp the output signal to be in the range [-1, 1] (Default: True).
explicit LFilter(std::vector<float> a_coeffs, std::vector<float> b_coeffs, bool clamp = true); explicit LFilter(const std::vector<float> &a_coeffs, const std::vector<float> &b_coeffs, bool clamp = true);
/// \brief Destructor. /// \brief Destructor.
~LFilter() = default; ~LFilter() = default;
@ -695,8 +693,8 @@ class MS_API Phaser final : public TensorTransform {
/// \param[in] mod_speed Modulation speed in Hz. Allowed range of values is [0.1, 2] (Default=0.5). /// \param[in] mod_speed Modulation speed in Hz. Allowed range of values is [0.1, 2] (Default=0.5).
/// \param[in] sinusoidal If true, use sinusoidal modulation (preferable for multiple instruments). /// \param[in] sinusoidal If true, use sinusoidal modulation (preferable for multiple instruments).
/// If false, use triangular modulation (gives single instruments a sharper phasing effect) (Default=true). /// If false, use triangular modulation (gives single instruments a sharper phasing effect) (Default=true).
Phaser(int32_t sample_rate, float gain_in = 0.4f, float gain_out = 0.74f, float delay_ms = 3.0f, float decay = 0.4f, explicit Phaser(int32_t sample_rate, float gain_in = 0.4f, float gain_out = 0.74f, float delay_ms = 3.0f,
float mod_speed = 0.5f, bool sinusoidal = true); float decay = 0.4f, float mod_speed = 0.5f, bool sinusoidal = true);
/// \brief Destructor. /// \brief Destructor.
~Phaser() = default; ~Phaser() = default;
@ -770,8 +768,8 @@ class MS_API SpectralCentroid : public TensorTransform {
/// \param[in] window Window function that is applied/multiplied to each frame/window, /// \param[in] window Window function that is applied/multiplied to each frame/window,
/// which can be WindowType::kBartlett, WindowType::kBlackman, WindowType::kHamming, /// which can be WindowType::kBartlett, WindowType::kBlackman, WindowType::kHamming,
/// WindowType::kHann or WindowType::kKaiser (Default: WindowType::kHann). /// WindowType::kHann or WindowType::kKaiser (Default: WindowType::kHann).
SpectralCentroid(int sample_rate, int32_t n_fft = 400, int32_t win_length = 0, int32_t hop_length = 0, explicit SpectralCentroid(int32_t sample_rate, int32_t n_fft = 400, int32_t win_length = 0, int32_t hop_length = 0,
int32_t pad = 0, WindowType window = WindowType::kHann); int32_t pad = 0, WindowType window = WindowType::kHann);
~SpectralCentroid() = default; ~SpectralCentroid() = default;
@ -807,9 +805,9 @@ class MS_API Spectrogram : public TensorTransform {
/// \param[in] center Whether to pad waveform on both sides (Default: true). /// \param[in] center Whether to pad waveform on both sides (Default: true).
/// \param[in] pad_mode Controls the padding method used when center is true (Default: BorderType::kReflect). /// \param[in] pad_mode Controls the padding method used when center is true (Default: BorderType::kReflect).
/// \param[in] onesided Controls whether to return half of results to avoid redundancy (Default: true). /// \param[in] onesided Controls whether to return half of results to avoid redundancy (Default: true).
Spectrogram(int32_t n_fft = 400, int32_t win_length = 0, int32_t hop_length = 0, int32_t pad = 0, explicit Spectrogram(int32_t n_fft = 400, int32_t win_length = 0, int32_t hop_length = 0, int32_t pad = 0,
WindowType window = WindowType::kHann, float power = 2.0, bool normalized = false, bool center = true, WindowType window = WindowType::kHann, float power = 2.0, bool normalized = false,
BorderType pad_mode = BorderType::kReflect, bool onesided = true); bool center = true, BorderType pad_mode = BorderType::kReflect, bool onesided = true);
/// \brief Destructor. /// \brief Destructor.
~Spectrogram() = default; ~Spectrogram() = default;

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020-2021 Huawei Technologies Co., Ltd * Copyright 2020-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -20,14 +20,14 @@
#include <cstdint> #include <cstdint>
#include <string> #include <string>
#include <vector> #include <vector>
#include "include/api/dual_abi_helper.h" #include "include/api/dual_abi_helper.h"
#include "include/api/types.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
// Config operations for setting and getting the configuration. // Config operations for setting and getting the configuration.
namespace config { namespace config {
/// \brief A function to set the seed to be used in any random generator. This is used to produce deterministic results. /// \brief A function to set the seed to be used in any random generator. This is used to produce deterministic results.
/// \param[in] seed The default seed to be used. /// \param[in] seed The default seed to be used.
/// \return The seed is set successfully or not. /// \return The seed is set successfully or not.
@ -155,10 +155,8 @@ bool MS_API load(const std::vector<char> &file);
/// std::string config_file = "/path/to/config/file"; /// std::string config_file = "/path/to/config/file";
/// bool rc = config::load(config_file); /// bool rc = config::load(config_file);
/// \endcode /// \endcode
inline bool MS_API load(std::string file) { return load(StringToChar(file)); } inline bool MS_API load(const std::string &file) { return load(StringToChar(file)); }
} // namespace config } // namespace config
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_CONFIG_H #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_CONFIG_H

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -36,7 +36,7 @@ namespace dataset {
class MS_API DataHelper { class MS_API DataHelper {
public: public:
/// \brief Constructor /// \brief Constructor
DataHelper() {} DataHelper() = default;
/// \brief Destructor /// \brief Destructor
~DataHelper() = default; ~DataHelper() = default;

File diff suppressed because it is too large Load Diff

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020-2021 Huawei Technologies Co., Ltd * Copyright 2020-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -36,22 +36,22 @@ class MS_API Execute {
/// \param[in] op TensorOperation to be applied in Eager mode, it accepts operation in type of shared pointer. /// \param[in] op TensorOperation to be applied in Eager mode, it accepts operation in type of shared pointer.
/// \param[in] device_type Target device environment to perform operation, can be kCPU/kGPU/kAscend310 (default=kCPU). /// \param[in] device_type Target device environment to perform operation, can be kCPU/kGPU/kAscend310 (default=kCPU).
/// \param[in] device_id Target device ID to perform operation, only valid when device_type=kAscend310 (default=0). /// \param[in] device_id Target device ID to perform operation, only valid when device_type=kAscend310 (default=0).
explicit Execute(std::shared_ptr<TensorOperation> op, MapTargetDevice device_type = MapTargetDevice::kCpu, explicit Execute(const std::shared_ptr<TensorOperation> &op, MapTargetDevice device_type = MapTargetDevice::kCpu,
uint32_t device_id = 0); uint32_t device_id = 0);
/// \brief Constructor. /// \brief Constructor.
/// \param[in] op TensorTransform to be applied in Eager mode, it accepts operation in type of shared pointer. /// \param[in] op TensorTransform to be applied in Eager mode, it accepts operation in type of shared pointer.
/// \param[in] device_type Target device environment to perform operation, can be kCPU/kGPU/kAscend310 (default=kCPU). /// \param[in] device_type Target device environment to perform operation, can be kCPU/kGPU/kAscend310 (default=kCPU).
/// \param[in] device_id Target device ID to perform operation, only valid when device_type=kAscend310 (default=0). /// \param[in] device_id Target device ID to perform operation, only valid when device_type=kAscend310 (default=0).
explicit Execute(std::shared_ptr<TensorTransform> op, MapTargetDevice device_type = MapTargetDevice::kCpu, explicit Execute(const std::shared_ptr<TensorTransform> &op, MapTargetDevice device_type = MapTargetDevice::kCpu,
uint32_t device_id = 0); uint32_t device_id = 0);
/// \brief Constructor. /// \brief Constructor.
/// \param[in] op TensorTransform to be applied in Eager mode, it accepts operation in type of reference. /// \param[in] op TensorTransform to be applied in Eager mode, it accepts operation in type of reference.
/// \param[in] device_type Target device environment to perform operation, can be kCPU/kGPU/kAscend310 (default=kCPU). /// \param[in] device_type Target device environment to perform operation, can be kCPU/kGPU/kAscend310 (default=kCPU).
/// \param[in] device_id Target device ID to perform operation, only valid when device_type=kAscend310 (default=0). /// \param[in] device_id Target device ID to perform operation, only valid when device_type=kAscend310 (default=0).
explicit Execute(std::reference_wrapper<TensorTransform> op, MapTargetDevice device_type = MapTargetDevice::kCpu, explicit Execute(const std::reference_wrapper<TensorTransform> &op,
uint32_t device_id = 0); MapTargetDevice device_type = MapTargetDevice::kCpu, uint32_t device_id = 0);
/// \brief Constructor. /// \brief Constructor.
/// \param[in] op TensorTransform to be applied in Eager mode, it accepts operation in type of raw pointer. /// \param[in] op TensorTransform to be applied in Eager mode, it accepts operation in type of raw pointer.
@ -64,7 +64,7 @@ class MS_API Execute {
/// in type of shared pointer. /// in type of shared pointer.
/// \param[in] device_type Target device environment to perform operation, can be kCPU/kGPU/kAscend310 (default=kCPU). /// \param[in] device_type Target device environment to perform operation, can be kCPU/kGPU/kAscend310 (default=kCPU).
/// \param[in] device_id Target device ID to perform operation, only valid when device_type=kAscend310 (default=0). /// \param[in] device_id Target device ID to perform operation, only valid when device_type=kAscend310 (default=0).
explicit Execute(std::vector<std::shared_ptr<TensorOperation>> ops, explicit Execute(const std::vector<std::shared_ptr<TensorOperation>> &ops,
MapTargetDevice device_type = MapTargetDevice::kCpu, uint32_t device_id = 0); MapTargetDevice device_type = MapTargetDevice::kCpu, uint32_t device_id = 0);
/// \brief Constructor. /// \brief Constructor.
@ -72,7 +72,7 @@ class MS_API Execute {
/// in type of shared pointer. /// in type of shared pointer.
/// \param[in] device_type Target device environment to perform operation, can be kCPU/kGPU/kAscend310 (default=kCPU). /// \param[in] device_type Target device environment to perform operation, can be kCPU/kGPU/kAscend310 (default=kCPU).
/// \param[in] device_id Target device ID to perform operation, only valid when device_type=kAscend310 (default=0). /// \param[in] device_id Target device ID to perform operation, only valid when device_type=kAscend310 (default=0).
explicit Execute(std::vector<std::shared_ptr<TensorTransform>> ops, explicit Execute(const std::vector<std::shared_ptr<TensorTransform>> &ops,
MapTargetDevice device_type = MapTargetDevice::kCpu, uint32_t device_id = 0); MapTargetDevice device_type = MapTargetDevice::kCpu, uint32_t device_id = 0);
/// \brief Constructor. /// \brief Constructor.
@ -80,7 +80,7 @@ class MS_API Execute {
/// in type of raw pointer. /// in type of raw pointer.
/// \param[in] device_type Target device environment to perform operation, can be kCPU/kGPU/kAscend310 (default=kCPU). /// \param[in] device_type Target device environment to perform operation, can be kCPU/kGPU/kAscend310 (default=kCPU).
/// \param[in] device_id Target device ID to perform operation, only valid when device_type=kAscend310 (default=0). /// \param[in] device_id Target device ID to perform operation, only valid when device_type=kAscend310 (default=0).
explicit Execute(const std::vector<std::reference_wrapper<TensorTransform>> ops, explicit Execute(const std::vector<std::reference_wrapper<TensorTransform>> &ops,
MapTargetDevice device_type = MapTargetDevice::kCpu, uint32_t device_id = 0); MapTargetDevice device_type = MapTargetDevice::kCpu, uint32_t device_id = 0);
/// \brief Constructor. /// \brief Constructor.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020-2021 Huawei Technologies Co., Ltd * Copyright 2020-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -22,13 +22,13 @@
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "include/api/dual_abi_helper.h" #include "include/api/dual_abi_helper.h"
#include "include/api/status.h" #include "include/api/status.h"
#include "include/api/types.h" #include "include/api/types.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
// Forward declare // Forward declare
class ExecutionTree; class ExecutionTree;
class DatasetOp; class DatasetOp;
@ -57,7 +57,7 @@ class MS_API Iterator {
/// \param[in] ds The last DatasetOp in the dataset pipeline. /// \param[in] ds The last DatasetOp in the dataset pipeline.
/// \param[in] num_epochs Number of epochs passed down to EpochCtrlNode (default=-1, which means infinite epochs). /// \param[in] num_epochs Number of epochs passed down to EpochCtrlNode (default=-1, which means infinite epochs).
/// \return Status error code, returns OK if no error encountered. /// \return Status error code, returns OK if no error encountered.
Status BuildAndLaunchTree(std::shared_ptr<Dataset> ds, int32_t num_epochs); Status BuildAndLaunchTree(const std::shared_ptr<Dataset> &ds, int32_t num_epochs);
/// \brief Function to get the next row from the data pipeline. /// \brief Function to get the next row from the data pipeline.
/// \note Type of return data is a unordered_map(with column name). /// \note Type of return data is a unordered_map(with column name).
@ -185,7 +185,7 @@ class MS_API PullIterator : public Iterator {
/// \note Consider making this function protected. /// \note Consider making this function protected.
/// \param[in] ds The root node that calls the function. /// \param[in] ds The root node that calls the function.
/// \return Status error code, returns OK if no error encountered. /// \return Status error code, returns OK if no error encountered.
Status BuildAndLaunchTree(std::shared_ptr<Dataset> ds); Status BuildAndLaunchTree(const std::shared_ptr<Dataset> &ds);
private: private:
std::unique_ptr<PullBasedIteratorConsumer> pull_consumer_; std::unique_ptr<PullBasedIteratorConsumer> pull_consumer_;

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020-2021 Huawei Technologies Co., Ltd * Copyright 2020-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -24,7 +24,6 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
// Forward declare // Forward declare
class SamplerObj; class SamplerObj;
@ -72,14 +71,14 @@ class MS_API Sampler : std::enable_shared_from_this<Sampler> {
public: public:
/// \brief Constructor /// \brief Constructor
Sampler() {} Sampler() = default;
/// \brief Destructor /// \brief Destructor
~Sampler() = default; ~Sampler() = default;
/// \brief A virtual function to add a child sampler. /// \brief A virtual function to add a child sampler.
/// \param[in] child The child sampler to be added as a children of this sampler. /// \param[in] child The child sampler to be added as a children of this sampler.
virtual void AddChild(std::shared_ptr<Sampler> child) { children_.push_back(child); } virtual void AddChild(const std::shared_ptr<Sampler> &child) { children_.push_back(child); }
protected: protected:
/// \brief Pure virtual function to convert a Sampler class into an IR Sampler object. /// \brief Pure virtual function to convert a Sampler class into an IR Sampler object.
@ -238,7 +237,7 @@ class MS_API SubsetSampler : public Sampler {
/// std::string folder_path = "/path/to/image/folder"; /// std::string folder_path = "/path/to/image/folder";
/// std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, std::make_shared<SubsetSampler>({0, 2, 5})); /// std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, std::make_shared<SubsetSampler>({0, 2, 5}));
/// \endcode /// \endcode
explicit SubsetSampler(std::vector<int64_t> indices, int64_t num_samples = 0); explicit SubsetSampler(const std::vector<int64_t> &indices, int64_t num_samples = 0);
/// \brief Destructor. /// \brief Destructor.
~SubsetSampler() = default; ~SubsetSampler() = default;
@ -267,7 +266,7 @@ class MS_API SubsetRandomSampler final : public SubsetSampler {
/// std::string folder_path = "/path/to/image/folder"; /// std::string folder_path = "/path/to/image/folder";
/// std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, std::make_shared<SubsetRandomSampler>({2, 7})); /// std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, std::make_shared<SubsetRandomSampler>({2, 7}));
/// \endcode /// \endcode
explicit SubsetRandomSampler(std::vector<int64_t> indices, int64_t num_samples = 0); explicit SubsetRandomSampler(const std::vector<int64_t> &indices, int64_t num_samples = 0);
/// \brief Destructor. /// \brief Destructor.
~SubsetRandomSampler() = default; ~SubsetRandomSampler() = default;
@ -297,7 +296,7 @@ class MS_API WeightedRandomSampler final : public Sampler {
/// std::string folder_path = "/path/to/image/folder"; /// std::string folder_path = "/path/to/image/folder";
/// std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, sampler); /// std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, sampler);
/// \endcode /// \endcode
explicit WeightedRandomSampler(std::vector<double> weights, int64_t num_samples = 0, bool replacement = true); explicit WeightedRandomSampler(const std::vector<double> &weights, int64_t num_samples = 0, bool replacement = true);
/// \brief Destructor. /// \brief Destructor.
~WeightedRandomSampler() = default; ~WeightedRandomSampler() = default;
@ -312,7 +311,6 @@ class MS_API WeightedRandomSampler final : public Sampler {
int64_t num_samples_; int64_t num_samples_;
bool replacement_; bool replacement_;
}; };
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_SAMPLERS_H_ #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_SAMPLERS_H_

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020-2021 Huawei Technologies Co., Ltd * Copyright 2020-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -30,7 +30,6 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
class SentencePieceVocab; class SentencePieceVocab;
class TensorOperation; class TensorOperation;
class Vectors; class Vectors;
@ -63,7 +62,7 @@ class MS_API BasicTokenizer final : public TensorTransform {
/// {"text"}); // input columns /// {"text"}); // input columns
/// \endcode /// \endcode
explicit BasicTokenizer(bool lower_case = false, bool keep_whitespace = false, explicit BasicTokenizer(bool lower_case = false, bool keep_whitespace = false,
const NormalizeForm normalize_form = NormalizeForm::kNone, bool preserve_unused_token = true, NormalizeForm normalize_form = NormalizeForm::kNone, bool preserve_unused_token = true,
bool with_offsets = false); bool with_offsets = false);
/// \brief Destructor /// \brief Destructor
@ -136,8 +135,7 @@ class MS_API BertTokenizer final : public TensorTransform {
/// \param[in] with_offsets Whether to output offsets of tokens (default=false). /// \param[in] with_offsets Whether to output offsets of tokens (default=false).
BertTokenizer(const std::shared_ptr<Vocab> &vocab, const std::vector<char> &suffix_indicator, BertTokenizer(const std::shared_ptr<Vocab> &vocab, const std::vector<char> &suffix_indicator,
int32_t max_bytes_per_token, const std::vector<char> &unknown_token, bool lower_case, int32_t max_bytes_per_token, const std::vector<char> &unknown_token, bool lower_case,
bool keep_whitespace, const NormalizeForm normalize_form, bool preserve_unused_token, bool keep_whitespace, NormalizeForm normalize_form, bool preserve_unused_token, bool with_offsets);
bool with_offsets);
/// \brief Destructor /// \brief Destructor
~BertTokenizer() = default; ~BertTokenizer() = default;
@ -448,7 +446,7 @@ class MS_API RegexReplace final : public TensorTransform {
/// dataset = dataset->Map({regex_op}, // operations /// dataset = dataset->Map({regex_op}, // operations
/// {"text"}); // input columns /// {"text"}); // input columns
/// \endcode /// \endcode
RegexReplace(std::string pattern, std::string replace, bool replace_all = true) RegexReplace(const std::string &pattern, const std::string &replace, bool replace_all = true)
: RegexReplace(StringToChar(pattern), StringToChar(replace), replace_all) {} : RegexReplace(StringToChar(pattern), StringToChar(replace), replace_all) {}
/// \brief Constructor. /// \brief Constructor.
@ -489,7 +487,8 @@ class MS_API RegexTokenizer final : public TensorTransform {
/// dataset = dataset->Map({regex_op}, // operations /// dataset = dataset->Map({regex_op}, // operations
/// {"text"}); // input columns /// {"text"}); // input columns
/// \endcode /// \endcode
explicit RegexTokenizer(std::string delim_pattern, std::string keep_delim_pattern = "", bool with_offsets = false) explicit RegexTokenizer(const std::string &delim_pattern, const std::string &keep_delim_pattern = "",
bool with_offsets = false)
: RegexTokenizer(StringToChar(delim_pattern), StringToChar(keep_delim_pattern), with_offsets) {} : RegexTokenizer(StringToChar(delim_pattern), StringToChar(keep_delim_pattern), with_offsets) {}
explicit RegexTokenizer(const std::vector<char> &delim_pattern, const std::vector<char> &keep_delim_pattern, explicit RegexTokenizer(const std::vector<char> &delim_pattern, const std::vector<char> &keep_delim_pattern,
@ -581,7 +580,7 @@ class MS_API SlidingWindow final : public TensorTransform {
/// dataset = dataset->Map({slidingwindow_op}, // operations /// dataset = dataset->Map({slidingwindow_op}, // operations
/// {"text"}); // input columns /// {"text"}); // input columns
/// \endcode /// \endcode
explicit SlidingWindow(const int32_t width, const int32_t axis = 0); explicit SlidingWindow(int32_t width, int32_t axis = 0);
/// \brief Destructor /// \brief Destructor
~SlidingWindow() = default; ~SlidingWindow() = default;
@ -637,7 +636,7 @@ class MS_API ToVectors final : public TensorTransform {
/// \param[in] unk_init In case of the token is out-of-vectors (OOV), the result will be initialized with `unk_init`. /// \param[in] unk_init In case of the token is out-of-vectors (OOV), the result will be initialized with `unk_init`.
/// (default={}, which means to initialize with zero vectors). /// (default={}, which means to initialize with zero vectors).
/// \param[in] lower_case_backup Whether to look up the token in the lower case (default=false). /// \param[in] lower_case_backup Whether to look up the token in the lower case (default=false).
explicit ToVectors(const std::shared_ptr<Vectors> &vectors, std::vector<float> unk_init = {}, explicit ToVectors(const std::shared_ptr<Vectors> &vectors, const std::vector<float> &unk_init = {},
bool lower_case_backup = false); bool lower_case_backup = false);
/// \brief Destructor /// \brief Destructor

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020-2021 Huawei Technologies Co., Ltd * Copyright 2020-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -29,7 +29,6 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
class TensorOperation; class TensorOperation;
// We need the following two groups of forward declaration to friend the class in class TensorTransform. // We need the following two groups of forward declaration to friend the class in class TensorTransform.
@ -60,7 +59,7 @@ class MS_API TensorTransform : public std::enable_shared_from_this<TensorTransfo
public: public:
/// \brief Constructor /// \brief Constructor
TensorTransform() {} TensorTransform() = default;
/// \brief Destructor /// \brief Destructor
~TensorTransform() = default; ~TensorTransform() = default;
@ -108,10 +107,13 @@ class MS_API SliceOption {
public: public:
/// \param[in] all Slice the whole dimension /// \param[in] all Slice the whole dimension
explicit SliceOption(bool all) : all_(all) {} explicit SliceOption(bool all) : all_(all) {}
/// \param[in] indices Slice these indices along the dimension. Negative indices are supported. /// \param[in] indices Slice these indices along the dimension. Negative indices are supported.
explicit SliceOption(std::vector<dsize_t> indices) : indices_(indices) {} explicit SliceOption(const std::vector<dsize_t> &indices) : indices_(indices) {}
/// \param[in] slice Slice the generated indices from the slice object along the dimension. /// \param[in] slice Slice the generated indices from the slice object along the dimension.
explicit SliceOption(Slice slice) : slice_(slice) {} explicit SliceOption(Slice slice) : slice_(slice) {}
SliceOption(SliceOption const &slice) = default; SliceOption(SliceOption const &slice) = default;
~SliceOption() = default; ~SliceOption() = default;

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020-2021 Huawei Technologies Co., Ltd * Copyright 2020-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -31,12 +31,10 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
class TensorOperation; class TensorOperation;
// Transform operations for performing computer vision. // Transform operations for performing computer vision.
namespace vision { namespace vision {
/// \brief AdjustGamma TensorTransform. /// \brief AdjustGamma TensorTransform.
/// \note Apply gamma correction on input image. /// \note Apply gamma correction on input image.
class MS_API AdjustGamma final : public TensorTransform { class MS_API AdjustGamma final : public TensorTransform {
@ -49,10 +47,10 @@ class MS_API AdjustGamma final : public TensorTransform {
/// \code /// \code
/// /* Define operations */ /// /* Define operations */
/// auto decode_op = vision::Decode(); /// auto decode_op = vision::Decode();
/// auto adjustgamma_op = vision::AdjustGamma(10.0); /// auto adjust_gamma_op = vision::AdjustGamma(10.0);
/// ///
/// /* dataset is an instance of Dataset object */ /// /* dataset is an instance of Dataset object */
/// dataset = dataset->Map({decode_op, adjustgamma_op}, // operations /// dataset = dataset->Map({decode_op, adjust_gamma_op}, // operations
/// {"image"}); // input columns /// {"image"}); // input columns
/// \endcode /// \endcode
explicit AdjustGamma(float gamma, float gain = 1); explicit AdjustGamma(float gamma, float gain = 1);
@ -93,9 +91,9 @@ class MS_API AutoAugment final : public TensorTransform {
/// dataset = dataset->Map({decode_op, auto_augment_op}, // operations /// dataset = dataset->Map({decode_op, auto_augment_op}, // operations
/// {"image"}); // input columns /// {"image"}); // input columns
/// \endcode /// \endcode
AutoAugment(AutoAugmentPolicy policy = AutoAugmentPolicy::kImageNet, explicit AutoAugment(AutoAugmentPolicy policy = AutoAugmentPolicy::kImageNet,
InterpolationMode interpolation = InterpolationMode::kNearestNeighbour, InterpolationMode interpolation = InterpolationMode::kNearestNeighbour,
const std::vector<uint8_t> &fill_value = {0, 0, 0}); const std::vector<uint8_t> &fill_value = {0, 0, 0});
/// \brief Destructor. /// \brief Destructor.
~AutoAugment() = default; ~AutoAugment() = default;
@ -126,7 +124,7 @@ class MS_API AutoContrast final : public TensorTransform {
/// dataset = dataset->Map({decode_op, autocontrast_op}, // operations /// dataset = dataset->Map({decode_op, autocontrast_op}, // operations
/// {"image"}); // input columns /// {"image"}); // input columns
/// \endcode /// \endcode
explicit AutoContrast(float cutoff = 0.0, std::vector<uint32_t> ignore = {}); explicit AutoContrast(float cutoff = 0.0, const std::vector<uint32_t> &ignore = {});
/// \brief Destructor. /// \brief Destructor.
~AutoContrast() = default; ~AutoContrast() = default;
@ -188,7 +186,7 @@ class MS_API BoundingBoxAugment final : public TensorTransform {
/// dataset = dataset->Map({bbox_aug_op}, // operations /// dataset = dataset->Map({bbox_aug_op}, // operations
/// {"image", "bbox"}); // input columns /// {"image", "bbox"}); // input columns
/// \endcode /// \endcode
explicit BoundingBoxAugment(const std::reference_wrapper<TensorTransform> transform, float ratio = 0.3); explicit BoundingBoxAugment(const std::reference_wrapper<TensorTransform> &transform, float ratio = 0.3);
/// \brief Destructor. /// \brief Destructor.
~BoundingBoxAugment() = default; ~BoundingBoxAugment() = default;
@ -473,7 +471,7 @@ class MS_API Pad final : public TensorTransform {
/// dataset = dataset->Map({decode_op, pad_op}, // operations /// dataset = dataset->Map({decode_op, pad_op}, // operations
/// {"image"}); // input columns /// {"image"}); // input columns
/// \endcode /// \endcode
explicit Pad(std::vector<int32_t> padding, std::vector<uint8_t> fill_value = {0}, explicit Pad(const std::vector<int32_t> &padding, const std::vector<uint8_t> &fill_value = {0},
BorderType padding_mode = BorderType::kConstant); BorderType padding_mode = BorderType::kConstant);
/// \brief Destructor. /// \brief Destructor.
@ -509,7 +507,7 @@ class MS_API RandomAutoContrast final : public TensorTransform {
/// dataset = dataset->Map({decode_op, random_auto_contrast_op}, // operations /// dataset = dataset->Map({decode_op, random_auto_contrast_op}, // operations
/// {"image"}); // input columns /// {"image"}); // input columns
/// \endcode /// \endcode
explicit RandomAutoContrast(float cutoff = 0.0, std::vector<uint32_t> ignore = {}, float prob = 0.5); explicit RandomAutoContrast(float cutoff = 0.0, const std::vector<uint32_t> &ignore = {}, float prob = 0.5);
/// \brief Destructor. /// \brief Destructor.
~RandomAutoContrast() = default; ~RandomAutoContrast() = default;
@ -612,8 +610,10 @@ class MS_API RandomColorAdjust final : public TensorTransform {
/// dataset = dataset->Map({decode_op, random_color_adjust_op}, // operations /// dataset = dataset->Map({decode_op, random_color_adjust_op}, // operations
/// {"image"}); // input columns /// {"image"}); // input columns
/// \endcode /// \endcode
explicit RandomColorAdjust(std::vector<float> brightness = {1.0, 1.0}, std::vector<float> contrast = {1.0, 1.0}, explicit RandomColorAdjust(const std::vector<float> &brightness = {1.0, 1.0},
std::vector<float> saturation = {1.0, 1.0}, std::vector<float> hue = {0.0, 0.0}); const std::vector<float> &contrast = {1.0, 1.0},
const std::vector<float> &saturation = {1.0, 1.0},
const std::vector<float> &hue = {0.0, 0.0});
/// \brief Destructor. /// \brief Destructor.
~RandomColorAdjust() = default; ~RandomColorAdjust() = default;
@ -663,8 +663,8 @@ class MS_API RandomCrop final : public TensorTransform {
/// dataset = dataset->Map({decode_op, random_crop_op}, // operations /// dataset = dataset->Map({decode_op, random_crop_op}, // operations
/// {"image"}); // input columns /// {"image"}); // input columns
/// \endcode /// \endcode
explicit RandomCrop(std::vector<int32_t> size, std::vector<int32_t> padding = {0, 0, 0, 0}, explicit RandomCrop(const std::vector<int32_t> &size, const std::vector<int32_t> &padding = {0, 0, 0, 0},
bool pad_if_needed = false, std::vector<uint8_t> fill_value = {0, 0, 0}, bool pad_if_needed = false, const std::vector<uint8_t> &fill_value = {0, 0, 0},
BorderType padding_mode = BorderType::kConstant); BorderType padding_mode = BorderType::kConstant);
/// \brief Destructor. /// \brief Destructor.
@ -708,8 +708,8 @@ class MS_API RandomCropDecodeResize final : public TensorTransform {
/// dataset = dataset->Map({random_op}, // operations /// dataset = dataset->Map({random_op}, // operations
/// {"image"}); // input columns /// {"image"}); // input columns
/// \endcode /// \endcode
explicit RandomCropDecodeResize(std::vector<int32_t> size, std::vector<float> scale = {0.08, 1.0}, explicit RandomCropDecodeResize(const std::vector<int32_t> &size, const std::vector<float> &scale = {0.08, 1.0},
std::vector<float> ratio = {3. / 4, 4. / 3}, const std::vector<float> &ratio = {3. / 4, 4. / 3},
InterpolationMode interpolation = InterpolationMode::kLinear, InterpolationMode interpolation = InterpolationMode::kLinear,
int32_t max_attempts = 10); int32_t max_attempts = 10);
@ -760,8 +760,8 @@ class MS_API RandomCropWithBBox final : public TensorTransform {
/// dataset = dataset->Map({random_op}, // operations /// dataset = dataset->Map({random_op}, // operations
/// {"image", "bbox"}); // input columns /// {"image", "bbox"}); // input columns
/// \endcode /// \endcode
explicit RandomCropWithBBox(std::vector<int32_t> size, std::vector<int32_t> padding = {0, 0, 0, 0}, explicit RandomCropWithBBox(const std::vector<int32_t> &size, const std::vector<int32_t> &padding = {0, 0, 0, 0},
bool pad_if_needed = false, std::vector<uint8_t> fill_value = {0, 0, 0}, bool pad_if_needed = false, const std::vector<uint8_t> &fill_value = {0, 0, 0},
BorderType padding_mode = BorderType::kConstant); BorderType padding_mode = BorderType::kConstant);
/// \brief Destructor. /// \brief Destructor.
@ -976,7 +976,7 @@ class MS_API RandomResize final : public TensorTransform {
/// dataset = dataset->Map({decode_op, random_op}, // operations /// dataset = dataset->Map({decode_op, random_op}, // operations
/// {"image"}); // input columns /// {"image"}); // input columns
/// \endcode /// \endcode
explicit RandomResize(std::vector<int32_t> size); explicit RandomResize(const std::vector<int32_t> &size);
/// \brief Destructor. /// \brief Destructor.
~RandomResize() = default; ~RandomResize() = default;
@ -1008,7 +1008,7 @@ class MS_API RandomResizeWithBBox final : public TensorTransform {
/// dataset = dataset->Map({random_op}, // operations /// dataset = dataset->Map({random_op}, // operations
/// {"image", "bbox"}); // input columns /// {"image", "bbox"}); // input columns
/// \endcode /// \endcode
explicit RandomResizeWithBBox(std::vector<int32_t> size); explicit RandomResizeWithBBox(const std::vector<int32_t> &size);
/// \brief Destructor. /// \brief Destructor.
~RandomResizeWithBBox() = default; ~RandomResizeWithBBox() = default;
@ -1053,8 +1053,8 @@ class MS_API RandomResizedCrop final : public TensorTransform {
/// dataset = dataset->Map({decode_op, random_op}, // operations /// dataset = dataset->Map({decode_op, random_op}, // operations
/// {"image"}); // input columns /// {"image"}); // input columns
/// \endcode /// \endcode
explicit RandomResizedCrop(std::vector<int32_t> size, std::vector<float> scale = {0.08, 1.0}, explicit RandomResizedCrop(const std::vector<int32_t> &size, const std::vector<float> &scale = {0.08, 1.0},
std::vector<float> ratio = {3. / 4., 4. / 3.}, const std::vector<float> &ratio = {3. / 4., 4. / 3.},
InterpolationMode interpolation = InterpolationMode::kLinear, int32_t max_attempts = 10); InterpolationMode interpolation = InterpolationMode::kLinear, int32_t max_attempts = 10);
/// \brief Destructor. /// \brief Destructor.
@ -1100,9 +1100,10 @@ class MS_API RandomResizedCropWithBBox final : public TensorTransform {
/// dataset = dataset->Map({random_op}, // operations /// dataset = dataset->Map({random_op}, // operations
/// {"image", "bbox"}); // input columns /// {"image", "bbox"}); // input columns
/// \endcode /// \endcode
RandomResizedCropWithBBox(std::vector<int32_t> size, std::vector<float> scale = {0.08, 1.0}, explicit RandomResizedCropWithBBox(const std::vector<int32_t> &size, const std::vector<float> &scale = {0.08, 1.0},
std::vector<float> ratio = {3. / 4., 4. / 3.}, const std::vector<float> &ratio = {3. / 4., 4. / 3.},
InterpolationMode interpolation = InterpolationMode::kLinear, int32_t max_attempts = 10); InterpolationMode interpolation = InterpolationMode::kLinear,
int32_t max_attempts = 10);
/// \brief Destructor. /// \brief Destructor.
~RandomResizedCropWithBBox() = default; ~RandomResizedCropWithBBox() = default;
@ -1144,8 +1145,9 @@ class MS_API RandomRotation final : public TensorTransform {
/// dataset = dataset->Map({decode_op, random_op}, // operations /// dataset = dataset->Map({decode_op, random_op}, // operations
/// {"image"}); // input columns /// {"image"}); // input columns
/// \endcode /// \endcode
RandomRotation(std::vector<float> degrees, InterpolationMode resample = InterpolationMode::kNearestNeighbour, explicit RandomRotation(const std::vector<float> &degrees,
bool expand = false, std::vector<float> center = {}, std::vector<uint8_t> fill_value = {0, 0, 0}); InterpolationMode resample = InterpolationMode::kNearestNeighbour, bool expand = false,
const std::vector<float> &center = {}, const std::vector<uint8_t> &fill_value = {0, 0, 0});
/// \brief Destructor. /// \brief Destructor.
~RandomRotation() = default; ~RandomRotation() = default;
@ -1255,7 +1257,7 @@ class MS_API RandomSharpness final : public TensorTransform {
/// dataset = dataset->Map({decode_op, random_op}, // operations /// dataset = dataset->Map({decode_op, random_op}, // operations
/// {"image"}); // input columns /// {"image"}); // input columns
/// \endcode /// \endcode
explicit RandomSharpness(std::vector<float> degrees = {0.1, 1.9}); explicit RandomSharpness(const std::vector<float> &degrees = {0.1, 1.9});
/// \brief Destructor. /// \brief Destructor.
~RandomSharpness() = default; ~RandomSharpness() = default;
@ -1287,7 +1289,7 @@ class MS_API RandomSolarize final : public TensorTransform {
/// dataset = dataset->Map({decode_op, random_op}, // operations /// dataset = dataset->Map({decode_op, random_op}, // operations
/// {"image"}); // input columns /// {"image"}); // input columns
/// \endcode /// \endcode
explicit RandomSolarize(std::vector<uint8_t> threshold = {0, 255}); explicit RandomSolarize(const std::vector<uint8_t> &threshold = {0, 255});
/// \brief Destructor. /// \brief Destructor.
~RandomSolarize() = default; ~RandomSolarize() = default;
@ -1414,7 +1416,8 @@ class MS_API ResizeWithBBox final : public TensorTransform {
/// dataset = dataset->Map({random_op}, // operations /// dataset = dataset->Map({random_op}, // operations
/// {"image", "bbox"}); // input columns /// {"image", "bbox"}); // input columns
/// \endcode /// \endcode
explicit ResizeWithBBox(std::vector<int32_t> size, InterpolationMode interpolation = InterpolationMode::kLinear); explicit ResizeWithBBox(const std::vector<int32_t> &size,
InterpolationMode interpolation = InterpolationMode::kLinear);
/// \brief Destructor. /// \brief Destructor.
~ResizeWithBBox() = default; ~ResizeWithBBox() = default;
@ -1500,8 +1503,8 @@ class MS_API SlicePatches final : public TensorTransform {
/// dataset = dataset->Map({decode_op, slice_patch_op}, // operations /// dataset = dataset->Map({decode_op, slice_patch_op}, // operations
/// {"image"}); // input columns /// {"image"}); // input columns
/// \endcode /// \endcode
SlicePatches(int32_t num_height = 1, int32_t num_width = 1, SliceMode slice_mode = SliceMode::kPad, explicit SlicePatches(int32_t num_height = 1, int32_t num_width = 1, SliceMode slice_mode = SliceMode::kPad,
uint8_t fill_value = 0); uint8_t fill_value = 0);
/// \brief Destructor. /// \brief Destructor.
~SlicePatches() = default; ~SlicePatches() = default;
@ -1542,8 +1545,10 @@ class MS_API SoftDvppDecodeRandomCropResizeJpeg final : public TensorTransform {
/// dataset = dataset->Map({dvpp_op}, // operations /// dataset = dataset->Map({dvpp_op}, // operations
/// {"image"}); // input columns /// {"image"}); // input columns
/// \endcode /// \endcode
SoftDvppDecodeRandomCropResizeJpeg(std::vector<int32_t> size, std::vector<float> scale = {0.08, 1.0}, explicit SoftDvppDecodeRandomCropResizeJpeg(const std::vector<int32_t> &size,
std::vector<float> ratio = {3. / 4., 4. / 3.}, int32_t max_attempts = 10); const std::vector<float> &scale = {0.08, 1.0},
const std::vector<float> &ratio = {3. / 4., 4. / 3.},
int32_t max_attempts = 10);
/// \brief Destructor. /// \brief Destructor.
~SoftDvppDecodeRandomCropResizeJpeg() = default; ~SoftDvppDecodeRandomCropResizeJpeg() = default;
@ -1581,7 +1586,7 @@ class MS_API SoftDvppDecodeResizeJpeg final : public TensorTransform {
/// dataset = dataset->Map({dvpp_op}, // operations /// dataset = dataset->Map({dvpp_op}, // operations
/// {"image"}); // input columns /// {"image"}); // input columns
/// \endcode /// \endcode
explicit SoftDvppDecodeResizeJpeg(std::vector<int32_t> size); explicit SoftDvppDecodeResizeJpeg(const std::vector<int32_t> &size);
/// \brief Destructor. /// \brief Destructor.
~SoftDvppDecodeResizeJpeg() = default; ~SoftDvppDecodeResizeJpeg() = default;
@ -1712,7 +1717,6 @@ class MS_API VerticalFlip final : public TensorTransform {
/// \return Shared pointer to TensorOperation object. /// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override; std::shared_ptr<TensorOperation> Parse() override;
}; };
} // namespace vision } // namespace vision
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -22,16 +22,15 @@
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "include/api/status.h" #include "include/api/status.h"
#include "include/dataset/constants.h" #include "include/dataset/constants.h"
#include "include/dataset/transforms.h" #include "include/dataset/transforms.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
// Transform operations for performing computer vision. // Transform operations for performing computer vision.
namespace vision { namespace vision {
/* ##################################### API class ###########################################*/ /* ##################################### API class ###########################################*/
/// \brief Decode and resize JPEG image using the hardware algorithm of /// \brief Decode and resize JPEG image using the hardware algorithm of
@ -49,7 +48,7 @@ class MS_API DvppDecodeResizeJpeg final : public TensorTransform {
/// dataset = dataset->Map({dvpp_op}, // operations /// dataset = dataset->Map({dvpp_op}, // operations
/// {"image"}); // input columns /// {"image"}); // input columns
/// \endcode /// \endcode
explicit DvppDecodeResizeJpeg(std::vector<uint32_t> resize); explicit DvppDecodeResizeJpeg(const std::vector<uint32_t> &resize);
/// \brief Destructor. /// \brief Destructor.
~DvppDecodeResizeJpeg() = default; ~DvppDecodeResizeJpeg() = default;
@ -82,7 +81,7 @@ class MS_API DvppDecodeResizeCropJpeg final : public TensorTransform {
/// dataset = dataset->Map({dvpp_op}, // operations /// dataset = dataset->Map({dvpp_op}, // operations
/// {"image"}); // input columns /// {"image"}); // input columns
/// \endcode /// \endcode
DvppDecodeResizeCropJpeg(std::vector<uint32_t> crop, std::vector<uint32_t> resize); DvppDecodeResizeCropJpeg(const std::vector<uint32_t> &crop, const std::vector<uint32_t> &resize);
/// \brief Destructor. /// \brief Destructor.
~DvppDecodeResizeCropJpeg() = default; ~DvppDecodeResizeCropJpeg() = default;
@ -125,7 +124,6 @@ class MS_API DvppDecodePng final : public TensorTransform {
std::shared_ptr<TensorOperation> Parse(const MapTargetDevice &env) override; std::shared_ptr<TensorOperation> Parse(const MapTargetDevice &env) override;
}; };
} // namespace vision } // namespace vision
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020-2021 Huawei Technologies Co., Ltd * Copyright 2020-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -22,16 +22,15 @@
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "include/api/status.h" #include "include/api/status.h"
#include "include/dataset/constants.h" #include "include/dataset/constants.h"
#include "include/dataset/transforms.h" #include "include/dataset/transforms.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
// Transform operations for performing computer vision. // Transform operations for performing computer vision.
namespace vision { namespace vision {
// Forward Declarations // Forward Declarations
class RotateOperation; class RotateOperation;
@ -96,7 +95,7 @@ class MS_API CenterCrop final : public TensorTransform {
/// dataset = dataset->Map({decode_op, crop_op}, // operations /// dataset = dataset->Map({decode_op, crop_op}, // operations
/// {"image"}); // input columns /// {"image"}); // input columns
/// \endcode /// \endcode
explicit CenterCrop(std::vector<int32_t> size); explicit CenterCrop(const std::vector<int32_t> &size);
/// \brief Destructor. /// \brief Destructor.
~CenterCrop() = default; ~CenterCrop() = default;
@ -131,7 +130,7 @@ class MS_API Crop final : public TensorTransform {
/// dataset = dataset->Map({decode_op, crop_op}, // operations /// dataset = dataset->Map({decode_op, crop_op}, // operations
/// {"image"}); // input columns /// {"image"}); // input columns
/// \endcode /// \endcode
Crop(std::vector<int32_t> coordinates, std::vector<int32_t> size); Crop(const std::vector<int32_t> &coordinates, const std::vector<int32_t> &size);
/// \brief Destructor. /// \brief Destructor.
~Crop() = default; ~Crop() = default;
@ -194,7 +193,7 @@ class MS_API GaussianBlur final : public TensorTransform {
/// dataset = dataset->Map({decode_op, gaussian_op}, // operations /// dataset = dataset->Map({decode_op, gaussian_op}, // operations
/// {"image"}); // input columns /// {"image"}); // input columns
/// \endcode /// \endcode
GaussianBlur(const std::vector<int32_t> &kernel_size, const std::vector<float> &sigma = {0., 0.}); explicit GaussianBlur(const std::vector<int32_t> &kernel_size, const std::vector<float> &sigma = {0., 0.});
/// \brief Destructor. /// \brief Destructor.
~GaussianBlur() = default; ~GaussianBlur() = default;
@ -227,7 +226,7 @@ class MS_API Normalize final : public TensorTransform {
/// dataset = dataset->Map({decode_op, normalize_op}, // operations /// dataset = dataset->Map({decode_op, normalize_op}, // operations
/// {"image"}); // input columns /// {"image"}); // input columns
/// \endcode /// \endcode
Normalize(std::vector<float> mean, std::vector<float> std); Normalize(const std::vector<float> &mean, const std::vector<float> &std);
/// \brief Destructor. /// \brief Destructor.
~Normalize() = default; ~Normalize() = default;
@ -318,7 +317,7 @@ class MS_API Resize final : public TensorTransform {
/// dataset = dataset->Map({decode_op, resize_op}, // operations /// dataset = dataset->Map({decode_op, resize_op}, // operations
/// {"image"}); // input columns /// {"image"}); // input columns
/// \endcode /// \endcode
explicit Resize(std::vector<int32_t> size, InterpolationMode interpolation = InterpolationMode::kLinear); explicit Resize(const std::vector<int32_t> &size, InterpolationMode interpolation = InterpolationMode::kLinear);
/// \brief Destructor. /// \brief Destructor.
~Resize() = default; ~Resize() = default;
@ -477,8 +476,8 @@ class MS_API Rotate final : public TensorTransform {
/// dataset = dataset->Map({decode_op, rotate_op}, // operations /// dataset = dataset->Map({decode_op, rotate_op}, // operations
/// {"image"}); // input columns /// {"image"}); // input columns
/// \endcode /// \endcode
Rotate(float degrees, InterpolationMode resample = InterpolationMode::kNearestNeighbour, bool expand = false, explicit Rotate(float degrees, InterpolationMode resample = InterpolationMode::kNearestNeighbour, bool expand = false,
std::vector<float> center = {}, std::vector<uint8_t> fill_value = {0, 0, 0}); const std::vector<float> &center = {}, const std::vector<uint8_t> &fill_value = {0, 0, 0});
/// \brief Destructor. /// \brief Destructor.
~Rotate() = default; ~Rotate() = default;
@ -493,7 +492,6 @@ class MS_API Rotate final : public TensorTransform {
struct Data; struct Data;
std::shared_ptr<Data> data_; std::shared_ptr<Data> data_;
}; };
} // namespace vision } // namespace vision
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020-2021 Huawei Technologies Co., Ltd * Copyright 2020-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_DATASETS_H_ #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_LITEAPI_INCLUDE_DATASETS_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_DATASETS_H_ #define MINDSPORE_CCSRC_MINDDATA_DATASET_LITEAPI_INCLUDE_DATASETS_H_
#include <sys/stat.h> #include <sys/stat.h>
#include <unistd.h> #include <unistd.h>
@ -38,7 +38,6 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
class Tensor; class Tensor;
class TensorShape; class TensorShape;
class TreeAdapter; class TreeAdapter;
@ -132,7 +131,7 @@ class MS_API Dataset : public std::enable_shared_from_this<Dataset> {
/// std::unordered_map<std::string, mindspore::MSTensor> row; /// std::unordered_map<std::string, mindspore::MSTensor> row;
/// iter->GetNextRow(&row); /// iter->GetNextRow(&row);
/// \endcode /// \endcode
std::shared_ptr<PullIterator> CreatePullBasedIterator(std::vector<std::vector<char>> columns = {}); std::shared_ptr<PullIterator> CreatePullBasedIterator(const std::vector<std::vector<char>> &columns = {});
/// \brief Function to create an Iterator over the Dataset pipeline /// \brief Function to create an Iterator over the Dataset pipeline
/// \param[in] columns List of columns to be used to specify the order of columns /// \param[in] columns List of columns to be used to specify the order of columns
@ -146,7 +145,7 @@ class MS_API Dataset : public std::enable_shared_from_this<Dataset> {
/// std::unordered_map<std::string, mindspore::MSTensor> row; /// std::unordered_map<std::string, mindspore::MSTensor> row;
/// iter->GetNextRow(&row); /// iter->GetNextRow(&row);
/// \endcode /// \endcode
std::shared_ptr<Iterator> CreateIterator(std::vector<std::string> columns = {}, int32_t num_epochs = -1) { std::shared_ptr<Iterator> CreateIterator(const std::vector<std::string> &columns = {}, int32_t num_epochs = -1) {
return CreateIteratorCharIF(VectorStringToChar(columns), num_epochs); return CreateIteratorCharIF(VectorStringToChar(columns), num_epochs);
} }
@ -162,7 +161,7 @@ class MS_API Dataset : public std::enable_shared_from_this<Dataset> {
/// \param[in] create_data_info_queue Whether to create queue which stores types and shapes /// \param[in] create_data_info_queue Whether to create queue which stores types and shapes
/// of data or not(default=false). /// of data or not(default=false).
/// \return Returns true if no error encountered else false. /// \return Returns true if no error encountered else false.
bool DeviceQueue(std::string queue_name = "", std::string device_type = "", int32_t device_id = 0, bool DeviceQueue(const std::string &queue_name = "", const std::string &device_type = "", int32_t device_id = 0,
int32_t num_epochs = -1, bool send_epoch_end = true, int32_t total_batches = 0, int32_t num_epochs = -1, bool send_epoch_end = true, int32_t total_batches = 0,
bool create_data_info_queue = false) { bool create_data_info_queue = false) {
return DeviceQueueCharIF(StringToChar(queue_name), StringToChar(device_type), device_id, num_epochs, send_epoch_end, return DeviceQueueCharIF(StringToChar(queue_name), StringToChar(device_type), device_id, num_epochs, send_epoch_end,
@ -189,7 +188,7 @@ class MS_API Dataset : public std::enable_shared_from_this<Dataset> {
/// std::string save_file = "Cifar10Data.mindrecord"; /// std::string save_file = "Cifar10Data.mindrecord";
/// bool rc = ds->Save(save_file); /// bool rc = ds->Save(save_file);
/// \endcode /// \endcode
bool Save(std::string dataset_path, int32_t num_files = 1, std::string dataset_type = "mindrecord") { bool Save(const std::string &dataset_path, int32_t num_files = 1, const std::string &dataset_type = "mindrecord") {
return SaveCharIF(StringToChar(dataset_path), num_files, StringToChar(dataset_type)); return SaveCharIF(StringToChar(dataset_path), num_files, StringToChar(dataset_type));
} }
@ -259,12 +258,12 @@ class MS_API Dataset : public std::enable_shared_from_this<Dataset> {
/// // columns will remain the same. /// // columns will remain the same.
/// dataset = dataset->Map({decode_op, random_jitter_op}, {"image"}) /// dataset = dataset->Map({decode_op, random_jitter_op}, {"image"})
/// \endcode /// \endcode
std::shared_ptr<MapDataset> Map(std::vector<TensorTransform *> operations, std::shared_ptr<MapDataset> Map(const std::vector<TensorTransform *> &operations,
const std::vector<std::string> &input_columns = {}, const std::vector<std::string> &input_columns = {},
const std::vector<std::string> &output_columns = {}, const std::vector<std::string> &output_columns = {},
const std::vector<std::string> &project_columns = {}, const std::vector<std::string> &project_columns = {},
const std::shared_ptr<DatasetCache> &cache = nullptr, const std::shared_ptr<DatasetCache> &cache = nullptr,
std::vector<std::shared_ptr<DSCallback>> callbacks = {}) { const std::vector<std::shared_ptr<DSCallback>> &callbacks = {}) {
std::vector<std::shared_ptr<TensorOperation>> transform_ops; std::vector<std::shared_ptr<TensorOperation>> transform_ops;
(void)std::transform( (void)std::transform(
operations.begin(), operations.end(), std::back_inserter(transform_ops), operations.begin(), operations.end(), std::back_inserter(transform_ops),
@ -290,15 +289,15 @@ class MS_API Dataset : public std::enable_shared_from_this<Dataset> {
/// \param[in] project_columns A list of column names to project /// \param[in] project_columns A list of column names to project
/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used). /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
/// \return Shared pointer to the current MapDataset /// \return Shared pointer to the current MapDataset
std::shared_ptr<MapDataset> Map(std::vector<std::shared_ptr<TensorTransform>> operations, std::shared_ptr<MapDataset> Map(const std::vector<std::shared_ptr<TensorTransform>> &operations,
const std::vector<std::string> &input_columns = {}, const std::vector<std::string> &input_columns = {},
const std::vector<std::string> &output_columns = {}, const std::vector<std::string> &output_columns = {},
const std::vector<std::string> &project_columns = {}, const std::vector<std::string> &project_columns = {},
const std::shared_ptr<DatasetCache> &cache = nullptr, const std::shared_ptr<DatasetCache> &cache = nullptr,
std::vector<std::shared_ptr<DSCallback>> callbacks = {}) { const std::vector<std::shared_ptr<DSCallback>> &callbacks = {}) {
std::vector<std::shared_ptr<TensorOperation>> transform_ops; std::vector<std::shared_ptr<TensorOperation>> transform_ops;
(void)std::transform(operations.begin(), operations.end(), std::back_inserter(transform_ops), (void)std::transform(operations.begin(), operations.end(), std::back_inserter(transform_ops),
[](std::shared_ptr<TensorTransform> op) -> std::shared_ptr<TensorOperation> { [](const std::shared_ptr<TensorTransform> &op) -> std::shared_ptr<TensorOperation> {
return op != nullptr ? op->Parse() : nullptr; return op != nullptr ? op->Parse() : nullptr;
}); });
return std::make_shared<MapDataset>(shared_from_this(), transform_ops, VectorStringToChar(input_columns), return std::make_shared<MapDataset>(shared_from_this(), transform_ops, VectorStringToChar(input_columns),
@ -322,12 +321,12 @@ class MS_API Dataset : public std::enable_shared_from_this<Dataset> {
/// \param[in] project_columns A list of column names to project /// \param[in] project_columns A list of column names to project
/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used). /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
/// \return Shared pointer to the current MapDataset /// \return Shared pointer to the current MapDataset
std::shared_ptr<MapDataset> Map(const std::vector<std::reference_wrapper<TensorTransform>> operations, std::shared_ptr<MapDataset> Map(const std::vector<std::reference_wrapper<TensorTransform>> &operations,
const std::vector<std::string> &input_columns = {}, const std::vector<std::string> &input_columns = {},
const std::vector<std::string> &output_columns = {}, const std::vector<std::string> &output_columns = {},
const std::vector<std::string> &project_columns = {}, const std::vector<std::string> &project_columns = {},
const std::shared_ptr<DatasetCache> &cache = nullptr, const std::shared_ptr<DatasetCache> &cache = nullptr,
std::vector<std::shared_ptr<DSCallback>> callbacks = {}) { const std::vector<std::shared_ptr<DSCallback>> &callbacks = {}) {
std::vector<std::shared_ptr<TensorOperation>> transform_ops; std::vector<std::shared_ptr<TensorOperation>> transform_ops;
(void)std::transform(operations.begin(), operations.end(), std::back_inserter(transform_ops), (void)std::transform(operations.begin(), operations.end(), std::back_inserter(transform_ops),
[](TensorTransform &op) -> std::shared_ptr<TensorOperation> { return op.Parse(); }); [](TensorTransform &op) -> std::shared_ptr<TensorOperation> { return op.Parse(); });
@ -378,7 +377,7 @@ class MS_API Dataset : public std::enable_shared_from_this<Dataset> {
std::vector<std::pair<std::vector<char>, std::vector<int32_t>>> GetClassIndexingCharIF(); std::vector<std::pair<std::vector<char>, std::vector<int32_t>>> GetClassIndexingCharIF();
// Char interface(CharIF) of CreateIterator // Char interface(CharIF) of CreateIterator
std::shared_ptr<Iterator> CreateIteratorCharIF(std::vector<std::vector<char>> columns, int32_t num_epochs); std::shared_ptr<Iterator> CreateIteratorCharIF(const std::vector<std::vector<char>> &columns, int32_t num_epochs);
// Char interface(CharIF) of DeviceQueue // Char interface(CharIF) of DeviceQueue
bool DeviceQueueCharIF(const std::vector<char> &queue_name, const std::vector<char> &device_type, int32_t device_id, bool DeviceQueueCharIF(const std::vector<char> &queue_name, const std::vector<char> &device_type, int32_t device_id,
@ -443,7 +442,7 @@ class MS_API SchemaObj {
std::string to_string() { return to_json(); } std::string to_string() { return to_json(); }
/// \brief Set a new value to dataset_type /// \brief Set a new value to dataset_type
void set_dataset_type(std::string dataset_type); void set_dataset_type(const std::string &dataset_type);
/// \brief Set a new value to num_rows /// \brief Set a new value to num_rows
void set_num_rows(int32_t num_rows); void set_num_rows(int32_t num_rows);
@ -492,28 +491,32 @@ class MS_API SchemaObj {
class MS_API BatchDataset : public Dataset { class MS_API BatchDataset : public Dataset {
public: public:
BatchDataset(std::shared_ptr<Dataset> input, int32_t batch_size, bool drop_remainder = false); BatchDataset(const std::shared_ptr<Dataset> &input, int32_t batch_size, bool drop_remainder = false);
~BatchDataset() = default; ~BatchDataset() = default;
}; };
class MS_API MapDataset : public Dataset { class MS_API MapDataset : public Dataset {
public: public:
MapDataset(std::shared_ptr<Dataset> input, std::vector<std::shared_ptr<TensorOperation>> operations, MapDataset(const std::shared_ptr<Dataset> &input, const std::vector<std::shared_ptr<TensorOperation>> &operations,
const std::vector<std::vector<char>> &input_columns, const std::vector<std::vector<char>> &output_columns, const std::vector<std::vector<char>> &input_columns, const std::vector<std::vector<char>> &output_columns,
const std::vector<std::vector<char>> &project_columns, const std::shared_ptr<DatasetCache> &cache, const std::vector<std::vector<char>> &project_columns, const std::shared_ptr<DatasetCache> &cache,
std::vector<std::shared_ptr<DSCallback>> callbacks); const std::vector<std::shared_ptr<DSCallback>> &callbacks);
~MapDataset() = default; ~MapDataset() = default;
}; };
class MS_API ProjectDataset : public Dataset { class MS_API ProjectDataset : public Dataset {
public: public:
ProjectDataset(std::shared_ptr<Dataset> input, const std::vector<std::vector<char>> &columns); ProjectDataset(const std::shared_ptr<Dataset> &input, const std::vector<std::vector<char>> &columns);
~ProjectDataset() = default; ~ProjectDataset() = default;
}; };
class MS_API ShuffleDataset : public Dataset { class MS_API ShuffleDataset : public Dataset {
public: public:
ShuffleDataset(std::shared_ptr<Dataset> input, int32_t buffer_size); ShuffleDataset(const std::shared_ptr<Dataset> &input, int32_t buffer_size);
~ShuffleDataset() = default; ~ShuffleDataset() = default;
}; };
@ -566,7 +569,7 @@ class MS_API AlbumDataset : public Dataset {
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used). /// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema, AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
const std::vector<std::vector<char>> &column_names, bool decode, const std::vector<std::vector<char>> &column_names, bool decode,
const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache); const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
/// \brief Destructor of AlbumDataset. /// \brief Destructor of AlbumDataset.
~AlbumDataset() = default; ~AlbumDataset() = default;
@ -664,7 +667,7 @@ class MS_API MnistDataset : public Dataset {
/// \param[in] sampler Sampler object used to choose samples from the dataset. /// \param[in] sampler Sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used). /// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache); const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
/// Destructor of MnistDataset. /// Destructor of MnistDataset.
~MnistDataset() = default; ~MnistDataset() = default;
@ -726,5 +729,4 @@ inline std::shared_ptr<MnistDataset> MS_API Mnist(const std::string &dataset_dir
} }
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_DATASETS_H_ #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_DATASETS_H_

View File

@ -54,7 +54,7 @@ TEST_F(MindDataTestPipeline, TestRiaaBiquadBasicSampleRate44100) {
std::vector<int64_t> expected = {2, 200}; std::vector<int64_t> expected = {2, 200};
int i = 0; int i = 0;
while (row.size() != 0) { while (!row.empty()) {
auto col = row["waveform"]; auto col = row["waveform"];
ASSERT_EQ(col.Shape(), expected); ASSERT_EQ(col.Shape(), expected);
ASSERT_EQ(col.Shape().size(), 2); ASSERT_EQ(col.Shape().size(), 2);
@ -93,7 +93,7 @@ TEST_F(MindDataTestPipeline, TestRiaaBiquadBasicSampleRate48000) {
std::vector<int64_t> expected = {30, 40}; std::vector<int64_t> expected = {30, 40};
int i = 0; int i = 0;
while (row.size() != 0) { while (!row.empty()) {
auto col = row["waveform"]; auto col = row["waveform"];
ASSERT_EQ(col.Shape(), expected); ASSERT_EQ(col.Shape(), expected);
ASSERT_EQ(col.Shape().size(), 2); ASSERT_EQ(col.Shape().size(), 2);
@ -132,7 +132,7 @@ TEST_F(MindDataTestPipeline, TestRiaaBiquadBasicSampleRate88200) {
std::vector<int64_t> expected = {5, 4}; std::vector<int64_t> expected = {5, 4};
int i = 0; int i = 0;
while (row.size() != 0) { while (!row.empty()) {
auto col = row["waveform"]; auto col = row["waveform"];
ASSERT_EQ(col.Shape(), expected); ASSERT_EQ(col.Shape(), expected);
ASSERT_EQ(col.Shape().size(), 2); ASSERT_EQ(col.Shape().size(), 2);
@ -171,7 +171,7 @@ TEST_F(MindDataTestPipeline, TestRiaaBiquadBasicSampleRate96000) {
std::vector<int64_t> expected = {2, 3}; std::vector<int64_t> expected = {2, 3};
int i = 0; int i = 0;
while (row.size() != 0) { while (!row.empty()) {
auto col = row["waveform"]; auto col = row["waveform"];
ASSERT_EQ(col.Shape(), expected); ASSERT_EQ(col.Shape(), expected);
ASSERT_EQ(col.Shape().size(), 2); ASSERT_EQ(col.Shape().size(), 2);
@ -221,7 +221,7 @@ TEST_F(MindDataTestPipeline, TestSlidingWindowCmn) {
std::unordered_map<std::string, mindspore::MSTensor> row; std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row)); ASSERT_OK(iter->GetNextRow(&row));
uint64_t i = 0; uint64_t i = 0;
while (row.size() != 0) { while (!row.empty()) {
i++; i++;
ASSERT_OK(iter->GetNextRow(&row)); ASSERT_OK(iter->GetNextRow(&row));
} }
@ -266,7 +266,7 @@ TEST_F(MindDataTestPipeline, TestSpectralCentroidBasic) {
std::shared_ptr<Dataset> ds = RandomData(8, schema); std::shared_ptr<Dataset> ds = RandomData(8, schema);
EXPECT_NE(ds, nullptr); EXPECT_NE(ds, nullptr);
auto spectral_centroid = audio::SpectralCentroid({44100, 8, 8, 4, 1, WindowType::kHann}); auto spectral_centroid = audio::SpectralCentroid(44100, 8, 8, 4, 1, WindowType::kHann);
auto ds1 = ds->Map({spectral_centroid}, {"waveform"}); auto ds1 = ds->Map({spectral_centroid}, {"waveform"});
EXPECT_NE(ds1, nullptr); EXPECT_NE(ds1, nullptr);
@ -278,7 +278,7 @@ TEST_F(MindDataTestPipeline, TestSpectralCentroidBasic) {
ASSERT_OK(iter->GetNextRow(&row)); ASSERT_OK(iter->GetNextRow(&row));
uint64_t i = 0; uint64_t i = 0;
while (row.size() != 0) { while (!row.empty()) {
ASSERT_OK(iter->GetNextRow(&row)); ASSERT_OK(iter->GetNextRow(&row));
i++; i++;
} }
@ -297,7 +297,7 @@ TEST_F(MindDataTestPipeline, TestSpectralCentroidDefault) {
std::shared_ptr<Dataset> ds = RandomData(8, schema); std::shared_ptr<Dataset> ds = RandomData(8, schema);
EXPECT_NE(ds, nullptr); EXPECT_NE(ds, nullptr);
auto spectral_centroid = audio::SpectralCentroid({44100}); auto spectral_centroid = audio::SpectralCentroid(44100);
auto ds1 = ds->Map({spectral_centroid}, {"waveform"}); auto ds1 = ds->Map({spectral_centroid}, {"waveform"});
EXPECT_NE(ds1, nullptr); EXPECT_NE(ds1, nullptr);
@ -309,7 +309,7 @@ TEST_F(MindDataTestPipeline, TestSpectralCentroidDefault) {
ASSERT_OK(iter->GetNextRow(&row)); ASSERT_OK(iter->GetNextRow(&row));
uint64_t i = 0; uint64_t i = 0;
while (row.size() != 0) { while (!row.empty()) {
ASSERT_OK(iter->GetNextRow(&row)); ASSERT_OK(iter->GetNextRow(&row));
i++; i++;
} }
@ -336,7 +336,7 @@ TEST_F(MindDataTestPipeline, TestSpectralCentroidWrongArgs) {
// Check n_fft // Check n_fft
MS_LOG(INFO) << "n_fft is zero."; MS_LOG(INFO) << "n_fft is zero.";
auto spectral_centroid_op_1 = audio::SpectralCentroid({44100, 0, 8, 4, 1, WindowType::kHann}); auto spectral_centroid_op_1 = audio::SpectralCentroid(44100, 0, 8, 4, 1, WindowType::kHann);
ds01 = ds->Map({spectral_centroid_op_1}); ds01 = ds->Map({spectral_centroid_op_1});
EXPECT_NE(ds01, nullptr); EXPECT_NE(ds01, nullptr);
@ -345,7 +345,7 @@ TEST_F(MindDataTestPipeline, TestSpectralCentroidWrongArgs) {
// Check win_length // Check win_length
MS_LOG(INFO) << "win_length is -1."; MS_LOG(INFO) << "win_length is -1.";
auto spectral_centroid_op_2 = audio::SpectralCentroid({44100, 8, -1, 4, 1, WindowType::kHann}); auto spectral_centroid_op_2 = audio::SpectralCentroid(44100, 8, -1, 4, 1, WindowType::kHann);
ds02 = ds->Map({spectral_centroid_op_2}); ds02 = ds->Map({spectral_centroid_op_2});
EXPECT_NE(ds02, nullptr); EXPECT_NE(ds02, nullptr);
@ -354,7 +354,7 @@ TEST_F(MindDataTestPipeline, TestSpectralCentroidWrongArgs) {
// Check hop_length // Check hop_length
MS_LOG(INFO) << "hop_length is -1."; MS_LOG(INFO) << "hop_length is -1.";
auto spectral_centroid_op_3 = audio::SpectralCentroid({44100, 8, 8, -1, 1, WindowType::kHann}); auto spectral_centroid_op_3 = audio::SpectralCentroid(44100, 8, 8, -1, 1, WindowType::kHann);
ds03 = ds->Map({spectral_centroid_op_3}); ds03 = ds->Map({spectral_centroid_op_3});
EXPECT_NE(ds03, nullptr); EXPECT_NE(ds03, nullptr);
@ -363,7 +363,7 @@ TEST_F(MindDataTestPipeline, TestSpectralCentroidWrongArgs) {
// Check pad // Check pad
MS_LOG(INFO) << "pad is -1."; MS_LOG(INFO) << "pad is -1.";
auto spectral_centroid_op_4 = audio::SpectralCentroid({44100, 8, 8, 4, -1, WindowType::kHann}); auto spectral_centroid_op_4 = audio::SpectralCentroid(44100, 8, 8, 4, -1, WindowType::kHann);
ds04 = ds->Map({spectral_centroid_op_4}); ds04 = ds->Map({spectral_centroid_op_4});
EXPECT_NE(ds04, nullptr); EXPECT_NE(ds04, nullptr);
@ -372,7 +372,7 @@ TEST_F(MindDataTestPipeline, TestSpectralCentroidWrongArgs) {
// Check sample_rate // Check sample_rate
MS_LOG(INFO) << "sample_rate is -1."; MS_LOG(INFO) << "sample_rate is -1.";
auto spectral_centroid_op_5 = audio::SpectralCentroid({-1, 8, 8, 4, 8, WindowType::kHann}); auto spectral_centroid_op_5 = audio::SpectralCentroid(-1, 8, 8, 4, 8, WindowType::kHann);
ds05 = ds->Map({spectral_centroid_op_5}); ds05 = ds->Map({spectral_centroid_op_5});
EXPECT_NE(ds05, nullptr); EXPECT_NE(ds05, nullptr);
@ -392,7 +392,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramDefault) {
EXPECT_NE(ds, nullptr); EXPECT_NE(ds, nullptr);
auto spectrogram = auto spectrogram =
audio::Spectrogram({40, 40, 20, 0, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true}); audio::Spectrogram(40, 40, 20, 0, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true);
auto ds1 = ds->Map({spectrogram}, {"waveform"}); auto ds1 = ds->Map({spectrogram}, {"waveform"});
EXPECT_NE(ds1, nullptr); EXPECT_NE(ds1, nullptr);
@ -404,7 +404,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramDefault) {
ASSERT_OK(iter->GetNextRow(&row)); ASSERT_OK(iter->GetNextRow(&row));
uint64_t i = 0; uint64_t i = 0;
while (row.size() != 0) { while (!row.empty()) {
ASSERT_OK(iter->GetNextRow(&row)); ASSERT_OK(iter->GetNextRow(&row));
i++; i++;
} }
@ -424,7 +424,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramOnesidedFalse) {
EXPECT_NE(ds, nullptr); EXPECT_NE(ds, nullptr);
auto spectrogram = auto spectrogram =
audio::Spectrogram({40, 40, 20, 0, WindowType::kHann, 2.0, false, true, BorderType::kReflect, false}); audio::Spectrogram(40, 40, 20, 0, WindowType::kHann, 2.0, false, true, BorderType::kReflect, false);
auto ds1 = ds->Map({spectrogram}, {"waveform"}); auto ds1 = ds->Map({spectrogram}, {"waveform"});
EXPECT_NE(ds1, nullptr); EXPECT_NE(ds1, nullptr);
@ -436,7 +436,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramOnesidedFalse) {
ASSERT_OK(iter->GetNextRow(&row)); ASSERT_OK(iter->GetNextRow(&row));
uint64_t i = 0; uint64_t i = 0;
while (row.size() != 0) { while (!row.empty()) {
ASSERT_OK(iter->GetNextRow(&row)); ASSERT_OK(iter->GetNextRow(&row));
i++; i++;
} }
@ -456,7 +456,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramCenterFalse) {
EXPECT_NE(ds, nullptr); EXPECT_NE(ds, nullptr);
auto spectrogram = auto spectrogram =
audio::Spectrogram({40, 40, 20, 0, WindowType::kHann, 2.0, false, false, BorderType::kReflect, true}); audio::Spectrogram(40, 40, 20, 0, WindowType::kHann, 2.0, false, false, BorderType::kReflect, true);
auto ds1 = ds->Map({spectrogram}, {"waveform"}); auto ds1 = ds->Map({spectrogram}, {"waveform"});
EXPECT_NE(ds1, nullptr); EXPECT_NE(ds1, nullptr);
@ -468,7 +468,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramCenterFalse) {
ASSERT_OK(iter->GetNextRow(&row)); ASSERT_OK(iter->GetNextRow(&row));
uint64_t i = 0; uint64_t i = 0;
while (row.size() != 0) { while (!row.empty()) {
ASSERT_OK(iter->GetNextRow(&row)); ASSERT_OK(iter->GetNextRow(&row));
i++; i++;
} }
@ -488,7 +488,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramNormalizedTrue) {
EXPECT_NE(ds, nullptr); EXPECT_NE(ds, nullptr);
auto spectrogram = auto spectrogram =
audio::Spectrogram({40, 40, 20, 0, WindowType::kHann, 2.0, true, true, BorderType::kReflect, true}); audio::Spectrogram(40, 40, 20, 0, WindowType::kHann, 2.0, true, true, BorderType::kReflect, true);
auto ds1 = ds->Map({spectrogram}, {"waveform"}); auto ds1 = ds->Map({spectrogram}, {"waveform"});
EXPECT_NE(ds1, nullptr); EXPECT_NE(ds1, nullptr);
@ -500,7 +500,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramNormalizedTrue) {
ASSERT_OK(iter->GetNextRow(&row)); ASSERT_OK(iter->GetNextRow(&row));
uint64_t i = 0; uint64_t i = 0;
while (row.size() != 0) { while (!row.empty()) {
ASSERT_OK(iter->GetNextRow(&row)); ASSERT_OK(iter->GetNextRow(&row));
i++; i++;
} }
@ -520,7 +520,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramWindowHamming) {
EXPECT_NE(ds, nullptr); EXPECT_NE(ds, nullptr);
auto spectrogram = auto spectrogram =
audio::Spectrogram({40, 40, 20, 0, WindowType::kHamming, 2.0, false, true, BorderType::kReflect, true}); audio::Spectrogram(40, 40, 20, 0, WindowType::kHamming, 2.0, false, true, BorderType::kReflect, true);
auto ds1 = ds->Map({spectrogram}, {"waveform"}); auto ds1 = ds->Map({spectrogram}, {"waveform"});
EXPECT_NE(ds1, nullptr); EXPECT_NE(ds1, nullptr);
@ -532,7 +532,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramWindowHamming) {
ASSERT_OK(iter->GetNextRow(&row)); ASSERT_OK(iter->GetNextRow(&row));
uint64_t i = 0; uint64_t i = 0;
while (row.size() != 0) { while (!row.empty()) {
ASSERT_OK(iter->GetNextRow(&row)); ASSERT_OK(iter->GetNextRow(&row));
i++; i++;
} }
@ -552,7 +552,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramPadmodeEdge) {
EXPECT_NE(ds, nullptr); EXPECT_NE(ds, nullptr);
auto spectrogram = auto spectrogram =
audio::Spectrogram({40, 40, 20, 0, WindowType::kHamming, 2.0, false, true, BorderType::kEdge, true}); audio::Spectrogram(40, 40, 20, 0, WindowType::kHamming, 2.0, false, true, BorderType::kEdge, true);
auto ds1 = ds->Map({spectrogram}, {"waveform"}); auto ds1 = ds->Map({spectrogram}, {"waveform"});
EXPECT_NE(ds1, nullptr); EXPECT_NE(ds1, nullptr);
@ -564,7 +564,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramPadmodeEdge) {
ASSERT_OK(iter->GetNextRow(&row)); ASSERT_OK(iter->GetNextRow(&row));
uint64_t i = 0; uint64_t i = 0;
while (row.size() != 0) { while (!row.empty()) {
ASSERT_OK(iter->GetNextRow(&row)); ASSERT_OK(iter->GetNextRow(&row));
i++; i++;
} }
@ -584,7 +584,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramPower0) {
EXPECT_NE(ds, nullptr); EXPECT_NE(ds, nullptr);
auto spectrogram = auto spectrogram =
audio::Spectrogram({40, 40, 20, 0, WindowType::kHamming, 0, false, true, BorderType::kReflect, true}); audio::Spectrogram(40, 40, 20, 0, WindowType::kHamming, 0, false, true, BorderType::kReflect, true);
auto ds1 = ds->Map({spectrogram}, {"waveform"}); auto ds1 = ds->Map({spectrogram}, {"waveform"});
EXPECT_NE(ds1, nullptr); EXPECT_NE(ds1, nullptr);
@ -596,7 +596,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramPower0) {
ASSERT_OK(iter->GetNextRow(&row)); ASSERT_OK(iter->GetNextRow(&row));
uint64_t i = 0; uint64_t i = 0;
while (row.size() != 0) { while (!row.empty()) {
ASSERT_OK(iter->GetNextRow(&row)); ASSERT_OK(iter->GetNextRow(&row));
i++; i++;
} }
@ -616,7 +616,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramNfft50) {
EXPECT_NE(ds, nullptr); EXPECT_NE(ds, nullptr);
auto spectrogram = auto spectrogram =
audio::Spectrogram({50, 40, 20, 0, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true}); audio::Spectrogram(50, 40, 20, 0, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true);
auto ds1 = ds->Map({spectrogram}, {"waveform"}); auto ds1 = ds->Map({spectrogram}, {"waveform"});
EXPECT_NE(ds1, nullptr); EXPECT_NE(ds1, nullptr);
@ -628,7 +628,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramNfft50) {
ASSERT_OK(iter->GetNextRow(&row)); ASSERT_OK(iter->GetNextRow(&row));
uint64_t i = 0; uint64_t i = 0;
while (row.size() != 0) { while (!row.empty()) {
ASSERT_OK(iter->GetNextRow(&row)); ASSERT_OK(iter->GetNextRow(&row));
i++; i++;
} }
@ -648,7 +648,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramPad10) {
EXPECT_NE(ds, nullptr); EXPECT_NE(ds, nullptr);
auto spectrogram = auto spectrogram =
audio::Spectrogram({40, 40, 20, 10, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true}); audio::Spectrogram(40, 40, 20, 10, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true);
auto ds1 = ds->Map({spectrogram}, {"waveform"}); auto ds1 = ds->Map({spectrogram}, {"waveform"});
EXPECT_NE(ds1, nullptr); EXPECT_NE(ds1, nullptr);
@ -660,7 +660,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramPad10) {
ASSERT_OK(iter->GetNextRow(&row)); ASSERT_OK(iter->GetNextRow(&row));
uint64_t i = 0; uint64_t i = 0;
while (row.size() != 0) { while (!row.empty()) {
ASSERT_OK(iter->GetNextRow(&row)); ASSERT_OK(iter->GetNextRow(&row));
i++; i++;
} }
@ -680,7 +680,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramWinlength30) {
EXPECT_NE(ds, nullptr); EXPECT_NE(ds, nullptr);
auto spectrogram = auto spectrogram =
audio::Spectrogram({40, 30, 20, 0, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true}); audio::Spectrogram(40, 30, 20, 0, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true);
auto ds1 = ds->Map({spectrogram}, {"waveform"}); auto ds1 = ds->Map({spectrogram}, {"waveform"});
EXPECT_NE(ds1, nullptr); EXPECT_NE(ds1, nullptr);
@ -692,7 +692,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramWinlength30) {
ASSERT_OK(iter->GetNextRow(&row)); ASSERT_OK(iter->GetNextRow(&row));
uint64_t i = 0; uint64_t i = 0;
while (row.size() != 0) { while (!row.empty()) {
ASSERT_OK(iter->GetNextRow(&row)); ASSERT_OK(iter->GetNextRow(&row));
i++; i++;
} }
@ -712,7 +712,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramHoplength30) {
EXPECT_NE(ds, nullptr); EXPECT_NE(ds, nullptr);
auto spectrogram = auto spectrogram =
audio::Spectrogram({40, 40, 30, 0, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true}); audio::Spectrogram(40, 40, 30, 0, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true);
auto ds1 = ds->Map({spectrogram}, {"waveform"}); auto ds1 = ds->Map({spectrogram}, {"waveform"});
EXPECT_NE(ds1, nullptr); EXPECT_NE(ds1, nullptr);
@ -724,7 +724,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramHoplength30) {
ASSERT_OK(iter->GetNextRow(&row)); ASSERT_OK(iter->GetNextRow(&row));
uint64_t i = 0; uint64_t i = 0;
while (row.size() != 0) { while (!row.empty()) {
ASSERT_OK(iter->GetNextRow(&row)); ASSERT_OK(iter->GetNextRow(&row));
i++; i++;
} }
@ -753,7 +753,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramWrongArgs) {
// Check n_fft // Check n_fft
MS_LOG(INFO) << "n_fft is zero."; MS_LOG(INFO) << "n_fft is zero.";
auto spectrogram_op_01 = auto spectrogram_op_01 =
audio::Spectrogram({0, 40, 20, 0, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true}); audio::Spectrogram(0, 40, 20, 0, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true);
ds01 = ds->Map({spectrogram_op_01}); ds01 = ds->Map({spectrogram_op_01});
EXPECT_NE(ds01, nullptr); EXPECT_NE(ds01, nullptr);
@ -763,7 +763,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramWrongArgs) {
// Check win_length // Check win_length
MS_LOG(INFO) << "win_length is -1."; MS_LOG(INFO) << "win_length is -1.";
auto spectrogram_op_02 = auto spectrogram_op_02 =
audio::Spectrogram({40, -1, 20, 0, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true}); audio::Spectrogram(40, -1, 20, 0, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true);
ds02 = ds->Map({spectrogram_op_02}); ds02 = ds->Map({spectrogram_op_02});
EXPECT_NE(ds02, nullptr); EXPECT_NE(ds02, nullptr);
@ -773,7 +773,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramWrongArgs) {
// Check hop_length // Check hop_length
MS_LOG(INFO) << "hop_length is -1."; MS_LOG(INFO) << "hop_length is -1.";
auto spectrogram_op_03 = auto spectrogram_op_03 =
audio::Spectrogram({40, 40, -1, 0, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true}); audio::Spectrogram(40, 40, -1, 0, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true);
ds03 = ds->Map({spectrogram_op_03}); ds03 = ds->Map({spectrogram_op_03});
EXPECT_NE(ds03, nullptr); EXPECT_NE(ds03, nullptr);
@ -783,7 +783,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramWrongArgs) {
// Check power // Check power
MS_LOG(INFO) << "power is -1."; MS_LOG(INFO) << "power is -1.";
auto spectrogram_op_04 = auto spectrogram_op_04 =
audio::Spectrogram({40, 40, 20, 0, WindowType::kHann, -1, false, true, BorderType::kReflect, true}); audio::Spectrogram(40, 40, 20, 0, WindowType::kHann, -1, false, true, BorderType::kReflect, true);
ds04 = ds->Map({spectrogram_op_04}); ds04 = ds->Map({spectrogram_op_04});
EXPECT_NE(ds04, nullptr); EXPECT_NE(ds04, nullptr);
@ -793,7 +793,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramWrongArgs) {
// Check pad // Check pad
MS_LOG(INFO) << "pad is -1."; MS_LOG(INFO) << "pad is -1.";
auto spectrogram_op_05 = auto spectrogram_op_05 =
audio::Spectrogram({40, 40, 20, -1, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true}); audio::Spectrogram(40, 40, 20, -1, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true);
ds05 = ds->Map({spectrogram_op_05}); ds05 = ds->Map({spectrogram_op_05});
EXPECT_NE(ds05, nullptr); EXPECT_NE(ds05, nullptr);
@ -803,7 +803,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramWrongArgs) {
// Check n_fft and win)length // Check n_fft and win)length
MS_LOG(INFO) << "n_fft is 40, win_length is 50."; MS_LOG(INFO) << "n_fft is 40, win_length is 50.";
auto spectrogram_op_06 = auto spectrogram_op_06 =
audio::Spectrogram({40, 50, 20, -1, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true}); audio::Spectrogram(40, 50, 20, -1, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true);
ds06 = ds->Map({spectrogram_op_06}); ds06 = ds->Map({spectrogram_op_06});
EXPECT_NE(ds06, nullptr); EXPECT_NE(ds06, nullptr);
@ -837,7 +837,7 @@ TEST_F(MindDataTestPipeline, TestTimeMaskingPipeline) {
std::vector<int64_t> expected = {2, 200}; std::vector<int64_t> expected = {2, 200};
int i = 0; int i = 0;
while (row.size() != 0) { while (!row.empty()) {
auto col = row["inputData"]; auto col = row["inputData"];
ASSERT_EQ(col.Shape(), expected); ASSERT_EQ(col.Shape(), expected);
ASSERT_EQ(col.Shape().size(), 2); ASSERT_EQ(col.Shape().size(), 2);
@ -901,7 +901,7 @@ TEST_F(MindDataTestPipeline, TestTimeStretchPipeline) {
std::vector<int64_t> expected = {2, freq, static_cast<int64_t>(std::ceil(400 / rate)), 2}; std::vector<int64_t> expected = {2, freq, static_cast<int64_t>(std::ceil(400 / rate)), 2};
int i = 0; int i = 0;
while (row.size() != 0) { while (!row.empty()) {
auto col = row["inputData"]; auto col = row["inputData"];
ASSERT_EQ(col.Shape(), expected); ASSERT_EQ(col.Shape(), expected);
ASSERT_EQ(col.DataType(), mindspore::DataType::kNumberTypeFloat32); ASSERT_EQ(col.DataType(), mindspore::DataType::kNumberTypeFloat32);
@ -965,7 +965,7 @@ TEST_F(MindDataTestPipeline, TestTrebleBiquadBasic) {
std::vector<int64_t> expected = {2, 200}; std::vector<int64_t> expected = {2, 200};
int i = 0; int i = 0;
while (row.size() != 0) { while (!row.empty()) {
auto col = row["waveform"]; auto col = row["waveform"];
ASSERT_EQ(col.Shape(), expected); ASSERT_EQ(col.Shape(), expected);
ASSERT_EQ(col.Shape().size(), 2); ASSERT_EQ(col.Shape().size(), 2);
@ -1032,7 +1032,7 @@ TEST_F(MindDataTestPipeline, TestVolPipeline) {
std::vector<int64_t> expected = {2, 200}; std::vector<int64_t> expected = {2, 200};
int i = 0; int i = 0;
while (row.size() != 0) { while (!row.empty()) {
auto col = row["inputData"]; auto col = row["inputData"];
ASSERT_EQ(col.Shape(), expected); ASSERT_EQ(col.Shape(), expected);
ASSERT_EQ(col.Shape().size(), 2); ASSERT_EQ(col.Shape().size(), 2);

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd # Copyright 2022 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.