update review problems
This commit is contained in:
parent
1ac5261dfe
commit
96a2e38d01
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -92,9 +92,10 @@ std::shared_ptr<TensorOperation> AmplitudeToDB::Parse() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Angle Transform Operation.
|
// Angle Transform Operation.
|
||||||
Angle::Angle() {}
|
Angle::Angle() = default;
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> Angle::Parse() { return std::make_shared<AngleOperation>(); }
|
std::shared_ptr<TensorOperation> Angle::Parse() { return std::make_shared<AngleOperation>(); }
|
||||||
|
|
||||||
// BandBiquad Transform Operation.
|
// BandBiquad Transform Operation.
|
||||||
struct BandBiquad::Data {
|
struct BandBiquad::Data {
|
||||||
Data(int32_t sample_rate, float central_freq, float Q, bool noise)
|
Data(int32_t sample_rate, float central_freq, float Q, bool noise)
|
||||||
|
@ -225,7 +226,7 @@ struct DBToAmplitude::Data {
|
||||||
float power_;
|
float power_;
|
||||||
};
|
};
|
||||||
|
|
||||||
DBToAmplitude::DBToAmplitude(float ref, float power) : data_(std::make_shared<Data>(power, power)) {}
|
DBToAmplitude::DBToAmplitude(float ref, float power) : data_(std::make_shared<Data>(ref, power)) {}
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> DBToAmplitude::Parse() {
|
std::shared_ptr<TensorOperation> DBToAmplitude::Parse() {
|
||||||
return std::make_shared<DBToAmplitudeOperation>(data_->ref_, data_->power_);
|
return std::make_shared<DBToAmplitudeOperation>(data_->ref_, data_->power_);
|
||||||
|
@ -431,7 +432,7 @@ struct LFilter::Data {
|
||||||
bool clamp_;
|
bool clamp_;
|
||||||
};
|
};
|
||||||
|
|
||||||
LFilter::LFilter(std::vector<float> a_coeffs, std::vector<float> b_coeffs, bool clamp)
|
LFilter::LFilter(const std::vector<float> &a_coeffs, const std::vector<float> &b_coeffs, bool clamp)
|
||||||
: data_(std::make_shared<Data>(a_coeffs, b_coeffs, clamp)) {}
|
: data_(std::make_shared<Data>(a_coeffs, b_coeffs, clamp)) {}
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> LFilter::Parse() {
|
std::shared_ptr<TensorOperation> LFilter::Parse() {
|
||||||
|
@ -585,6 +586,17 @@ struct Spectrogram::Data {
|
||||||
bool onesided_;
|
bool onesided_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Spectrogram::Spectrogram(int32_t n_fft, int32_t win_length, int32_t hop_length, int32_t pad, WindowType window,
|
||||||
|
float power, bool normalized, bool center, BorderType pad_mode, bool onesided)
|
||||||
|
: data_(std::make_shared<Data>(n_fft, win_length, hop_length, pad, window, power, normalized, center, pad_mode,
|
||||||
|
onesided)) {}
|
||||||
|
|
||||||
|
std::shared_ptr<TensorOperation> Spectrogram::Parse() {
|
||||||
|
return std::make_shared<SpectrogramOperation>(data_->n_fft_, data_->win_length_, data_->hop_length_, data_->pad_,
|
||||||
|
data_->window_, data_->power_, data_->normalized_, data_->center_,
|
||||||
|
data_->pad_mode_, data_->onesided_);
|
||||||
|
}
|
||||||
|
|
||||||
// SpectralCentroid Transform Operation.
|
// SpectralCentroid Transform Operation.
|
||||||
struct SpectralCentroid::Data {
|
struct SpectralCentroid::Data {
|
||||||
Data(int32_t sample_rate, int32_t n_fft, int32_t win_length, int32_t hop_length, int32_t pad, WindowType window)
|
Data(int32_t sample_rate, int32_t n_fft, int32_t win_length, int32_t hop_length, int32_t pad, WindowType window)
|
||||||
|
@ -611,17 +623,6 @@ std::shared_ptr<TensorOperation> SpectralCentroid::Parse() {
|
||||||
data_->hop_length_, data_->pad_, data_->window_);
|
data_->hop_length_, data_->pad_, data_->window_);
|
||||||
}
|
}
|
||||||
|
|
||||||
Spectrogram::Spectrogram(int32_t n_fft, int32_t win_length, int32_t hop_length, int32_t pad, WindowType window,
|
|
||||||
float power, bool normalized, bool center, BorderType pad_mode, bool onesided)
|
|
||||||
: data_(std::make_shared<Data>(n_fft, win_length, hop_length, pad, window, power, normalized, center, pad_mode,
|
|
||||||
onesided)) {}
|
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> Spectrogram::Parse() {
|
|
||||||
return std::make_shared<SpectrogramOperation>(data_->n_fft_, data_->win_length_, data_->hop_length_, data_->pad_,
|
|
||||||
data_->window_, data_->power_, data_->normalized_, data_->center_,
|
|
||||||
data_->pad_mode_, data_->onesided_);
|
|
||||||
}
|
|
||||||
|
|
||||||
// TimeMasking Transform Operation.
|
// TimeMasking Transform Operation.
|
||||||
struct TimeMasking::Data {
|
struct TimeMasking::Data {
|
||||||
Data(bool iid_masks, int32_t time_mask_param, int32_t mask_start, float mask_value)
|
Data(bool iid_masks, int32_t time_mask_param, int32_t mask_start, float mask_value)
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -13,19 +13,17 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
#include "minddata/dataset/include/dataset/config.h"
|
||||||
|
|
||||||
#include "minddata/dataset/core/config_manager.h"
|
#include "minddata/dataset/core/config_manager.h"
|
||||||
#include "minddata/dataset/core/global_context.h"
|
#include "minddata/dataset/core/global_context.h"
|
||||||
#include "minddata/dataset/include/dataset/config.h"
|
|
||||||
#include "minddata/dataset/util/log_adapter.h"
|
#include "minddata/dataset/util/log_adapter.h"
|
||||||
#include "minddata/dataset/util/status.h"
|
#include "minddata/dataset/util/status.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
|
||||||
// Config operations for setting and getting the configuration.
|
// Config operations for setting and getting the configuration.
|
||||||
namespace config {
|
namespace config {
|
||||||
|
|
||||||
std::shared_ptr<ConfigManager> _config = GlobalContext::config_manager();
|
std::shared_ptr<ConfigManager> _config = GlobalContext::config_manager();
|
||||||
|
|
||||||
// Function to set the seed to be used in any random generator
|
// Function to set the seed to be used in any random generator
|
||||||
|
@ -102,7 +100,6 @@ bool load(const std::vector<char> &file) {
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace config
|
} // namespace config
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -121,41 +121,49 @@ Status DataHelper::UpdateValueIF(const std::vector<char> &in_file, const std::ve
|
||||||
auto jh = JsonHelper();
|
auto jh = JsonHelper();
|
||||||
return jh.UpdateValue(CharToString(in_file), CharToString(key), value, CharToString(out_file));
|
return jh.UpdateValue(CharToString(in_file), CharToString(key), value, CharToString(out_file));
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DataHelper::UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const uint8_t &value,
|
Status DataHelper::UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const uint8_t &value,
|
||||||
const std::vector<char> &out_file) {
|
const std::vector<char> &out_file) {
|
||||||
auto jh = JsonHelper();
|
auto jh = JsonHelper();
|
||||||
return jh.UpdateValue(CharToString(in_file), CharToString(key), value, CharToString(out_file));
|
return jh.UpdateValue(CharToString(in_file), CharToString(key), value, CharToString(out_file));
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DataHelper::UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const int16_t &value,
|
Status DataHelper::UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const int16_t &value,
|
||||||
const std::vector<char> &out_file) {
|
const std::vector<char> &out_file) {
|
||||||
auto jh = JsonHelper();
|
auto jh = JsonHelper();
|
||||||
return jh.UpdateValue(CharToString(in_file), CharToString(key), value, CharToString(out_file));
|
return jh.UpdateValue(CharToString(in_file), CharToString(key), value, CharToString(out_file));
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DataHelper::UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const uint16_t &value,
|
Status DataHelper::UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const uint16_t &value,
|
||||||
const std::vector<char> &out_file) {
|
const std::vector<char> &out_file) {
|
||||||
auto jh = JsonHelper();
|
auto jh = JsonHelper();
|
||||||
return jh.UpdateValue(CharToString(in_file), CharToString(key), value, CharToString(out_file));
|
return jh.UpdateValue(CharToString(in_file), CharToString(key), value, CharToString(out_file));
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DataHelper::UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const int32_t &value,
|
Status DataHelper::UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const int32_t &value,
|
||||||
const std::vector<char> &out_file) {
|
const std::vector<char> &out_file) {
|
||||||
auto jh = JsonHelper();
|
auto jh = JsonHelper();
|
||||||
return jh.UpdateValue(CharToString(in_file), CharToString(key), value, CharToString(out_file));
|
return jh.UpdateValue(CharToString(in_file), CharToString(key), value, CharToString(out_file));
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DataHelper::UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const uint32_t &value,
|
Status DataHelper::UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const uint32_t &value,
|
||||||
const std::vector<char> &out_file) {
|
const std::vector<char> &out_file) {
|
||||||
auto jh = JsonHelper();
|
auto jh = JsonHelper();
|
||||||
return jh.UpdateValue(CharToString(in_file), CharToString(key), value, CharToString(out_file));
|
return jh.UpdateValue(CharToString(in_file), CharToString(key), value, CharToString(out_file));
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DataHelper::UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const int64_t &value,
|
Status DataHelper::UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const int64_t &value,
|
||||||
const std::vector<char> &out_file) {
|
const std::vector<char> &out_file) {
|
||||||
auto jh = JsonHelper();
|
auto jh = JsonHelper();
|
||||||
return jh.UpdateValue(CharToString(in_file), CharToString(key), value, CharToString(out_file));
|
return jh.UpdateValue(CharToString(in_file), CharToString(key), value, CharToString(out_file));
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DataHelper::UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const uint64_t &value,
|
Status DataHelper::UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const uint64_t &value,
|
||||||
const std::vector<char> &out_file) {
|
const std::vector<char> &out_file) {
|
||||||
auto jh = JsonHelper();
|
auto jh = JsonHelper();
|
||||||
return jh.UpdateValue(CharToString(in_file), CharToString(key), value, CharToString(out_file));
|
return jh.UpdateValue(CharToString(in_file), CharToString(key), value, CharToString(out_file));
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DataHelper::UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const float &value,
|
Status DataHelper::UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const float &value,
|
||||||
const std::vector<char> &out_file) {
|
const std::vector<char> &out_file) {
|
||||||
auto jh = JsonHelper();
|
auto jh = JsonHelper();
|
||||||
|
@ -179,6 +187,5 @@ size_t DataHelper::DumpData(const unsigned char *tensor_addr, const size_t &tens
|
||||||
auto jh = JsonHelper();
|
auto jh = JsonHelper();
|
||||||
return jh.DumpData(tensor_addr, tensor_size, addr, buffer_size);
|
return jh.DumpData(tensor_addr, tensor_size, addr, buffer_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -15,33 +15,29 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "minddata/dataset/include/dataset/datasets.h"
|
#include "minddata/dataset/include/dataset/datasets.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include <nlohmann/json.hpp>
|
#include <nlohmann/json.hpp>
|
||||||
|
|
||||||
#include "minddata/dataset/core/tensor.h"
|
#include "minddata/dataset/core/tensor.h"
|
||||||
|
#include "minddata/dataset/core/type_id.h"
|
||||||
|
#include "minddata/dataset/engine/consumers/pull_based_tree_consumer.h"
|
||||||
|
#include "minddata/dataset/engine/consumers/tree_consumer.h"
|
||||||
#include "minddata/dataset/engine/runtime_context.h"
|
#include "minddata/dataset/engine/runtime_context.h"
|
||||||
#include "minddata/dataset/include/dataset/constants.h"
|
#include "minddata/dataset/include/dataset/constants.h"
|
||||||
#include "minddata/dataset/include/dataset/iterator.h"
|
#include "minddata/dataset/include/dataset/iterator.h"
|
||||||
#include "minddata/dataset/include/dataset/samplers.h"
|
#include "minddata/dataset/include/dataset/samplers.h"
|
||||||
#include "minddata/dataset/include/dataset/transforms.h"
|
|
||||||
#include "minddata/dataset/util/path.h"
|
|
||||||
#include "minddata/dataset/util/status.h"
|
|
||||||
#include "minddata/dataset/core/client.h"
|
|
||||||
#include "minddata/dataset/core/type_id.h"
|
|
||||||
#include "minddata/dataset/engine/consumers/tree_consumer.h"
|
|
||||||
#include "minddata/dataset/engine/consumers/pull_based_tree_consumer.h"
|
|
||||||
|
|
||||||
#include "minddata/dataset/kernels/c_func_op.h"
|
#include "minddata/dataset/kernels/c_func_op.h"
|
||||||
#include "minddata/dataset/kernels/tensor_op.h"
|
#include "minddata/dataset/kernels/tensor_op.h"
|
||||||
|
#include "minddata/dataset/util/path.h"
|
||||||
|
#include "minddata/dataset/util/random.h"
|
||||||
|
#include "minddata/dataset/util/status.h"
|
||||||
#ifndef ENABLE_ANDROID
|
#ifndef ENABLE_ANDROID
|
||||||
#include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h"
|
#include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h"
|
||||||
#endif
|
|
||||||
|
|
||||||
#ifndef ENABLE_ANDROID
|
|
||||||
#include "minddata/dataset/text/sentence_piece_vocab.h"
|
#include "minddata/dataset/text/sentence_piece_vocab.h"
|
||||||
#include "minddata/dataset/text/vocab.h"
|
#include "minddata/dataset/text/vocab.h"
|
||||||
#endif
|
#endif
|
||||||
|
@ -49,6 +45,7 @@
|
||||||
// Sampler headers (in alphabetical order)
|
// Sampler headers (in alphabetical order)
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"
|
||||||
|
|
||||||
|
// IR dataset node
|
||||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||||
|
|
||||||
// IR non-leaf nodes
|
// IR non-leaf nodes
|
||||||
|
@ -60,17 +57,13 @@
|
||||||
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
|
||||||
#include "minddata/dataset/engine/ir/datasetops/filter_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/filter_node.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
|
||||||
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
|
||||||
|
|
||||||
#ifndef ENABLE_ANDROID
|
#ifndef ENABLE_ANDROID
|
||||||
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
|
||||||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||||
|
|
||||||
#ifndef ENABLE_ANDROID
|
#ifndef ENABLE_ANDROID
|
||||||
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
|
||||||
#include "minddata/dataset/engine/ir/datasetops/take_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/take_node.h"
|
||||||
|
@ -78,22 +71,15 @@
|
||||||
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include "minddata/dataset/core/config_manager.h"
|
|
||||||
#include "minddata/dataset/util/random.h"
|
|
||||||
#include "minddata/dataset/util/services.h"
|
|
||||||
|
|
||||||
// IR leaf nodes
|
// IR leaf nodes
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/ag_news_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/ag_news_node.h"
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/album_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/album_node.h"
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
|
|
||||||
|
|
||||||
// IR leaf nodes disabled for android
|
|
||||||
#ifndef ENABLE_ANDROID
|
#ifndef ENABLE_ANDROID
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/amazon_review_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/amazon_review_node.h"
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/caltech256_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/caltech256_node.h"
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h"
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h"
|
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
|
||||||
|
#include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h"
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/cityscapes_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/cityscapes_node.h"
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
|
||||||
|
@ -114,6 +100,9 @@
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/lj_speech_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/lj_speech_node.h"
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
|
||||||
|
#endif
|
||||||
|
#include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
|
||||||
|
#ifndef ENABLE_ANDROID
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/penn_treebank_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/penn_treebank_node.h"
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/photo_tour_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/photo_tour_node.h"
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/places365_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/places365_node.h"
|
||||||
|
@ -139,7 +128,6 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
|
||||||
// convert MSTensorVec to DE TensorRow, return empty if fails
|
// convert MSTensorVec to DE TensorRow, return empty if fails
|
||||||
TensorRow VecToRow(const MSTensorVec &v) {
|
TensorRow VecToRow(const MSTensorVec &v) {
|
||||||
TensorRow row;
|
TensorRow row;
|
||||||
|
@ -160,19 +148,20 @@ TensorRow VecToRow(const MSTensorVec &v) {
|
||||||
MSTensorVec RowToVec(const TensorRow &v) {
|
MSTensorVec RowToVec(const TensorRow &v) {
|
||||||
MSTensorVec rv;
|
MSTensorVec rv;
|
||||||
rv.reserve(v.size());
|
rv.reserve(v.size());
|
||||||
std::transform(v.begin(), v.end(), std::back_inserter(rv), [](std::shared_ptr<Tensor> t) -> MSTensor {
|
std::transform(v.begin(), v.end(), std::back_inserter(rv), [](const std::shared_ptr<Tensor> &t) -> MSTensor {
|
||||||
return mindspore::MSTensor(std::make_shared<DETensor>(t));
|
return mindspore::MSTensor(std::make_shared<DETensor>(t));
|
||||||
});
|
});
|
||||||
return rv;
|
return rv;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert a std::function<TensorRow(TensorRow)> to std::function<MSTensorVec(MSTensor)> with this helper
|
// Convert a std::function<TensorRow(TensorRow)> to std::function<MSTensorVec(MSTensor)> with this helper
|
||||||
TensorRow FuncPtrConverter(std::function<MSTensorVec(MSTensorVec)> func, TensorRow in_row) {
|
TensorRow FuncPtrConverter(const std::function<MSTensorVec(MSTensorVec)> &func, const TensorRow &in_row) {
|
||||||
return VecToRow(func(RowToVec(in_row)));
|
return VecToRow(func(RowToVec(in_row)));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Function to create the iterator, which will build and launch the execution tree.
|
// Function to create the iterator, which will build and launch the execution tree.
|
||||||
std::shared_ptr<Iterator> Dataset::CreateIteratorCharIF(std::vector<std::vector<char>> columns, int32_t num_epochs) {
|
std::shared_ptr<Iterator> Dataset::CreateIteratorCharIF(const std::vector<std::vector<char>> &columns,
|
||||||
|
int32_t num_epochs) {
|
||||||
std::shared_ptr<Iterator> iter;
|
std::shared_ptr<Iterator> iter;
|
||||||
try {
|
try {
|
||||||
auto ds = shared_from_this();
|
auto ds = shared_from_this();
|
||||||
|
@ -198,7 +187,7 @@ std::shared_ptr<Iterator> Dataset::CreateIteratorCharIF(std::vector<std::vector<
|
||||||
}
|
}
|
||||||
|
|
||||||
// Function to create the iterator, which will build and launch the execution tree.
|
// Function to create the iterator, which will build and launch the execution tree.
|
||||||
std::shared_ptr<PullIterator> Dataset::CreatePullBasedIterator(std::vector<std::vector<char>> columns) {
|
std::shared_ptr<PullIterator> Dataset::CreatePullBasedIterator(const std::vector<std::vector<char>> &columns) {
|
||||||
// The specified columns will be selected from the dataset and passed down the pipeline
|
// The specified columns will be selected from the dataset and passed down the pipeline
|
||||||
// in the order specified, other columns will be discarded.
|
// in the order specified, other columns will be discarded.
|
||||||
// This code is not in a try/catch block because there is no execution tree class that will be created.
|
// This code is not in a try/catch block because there is no execution tree class that will be created.
|
||||||
|
@ -424,7 +413,7 @@ std::shared_ptr<SchemaObj> SchemaCharIF(const std::vector<char> &schema_file) {
|
||||||
// (In alphabetical order)
|
// (In alphabetical order)
|
||||||
|
|
||||||
// Function to create a Batch dataset
|
// Function to create a Batch dataset
|
||||||
BatchDataset::BatchDataset(std::shared_ptr<Dataset> input, int32_t batch_size, bool drop_remainder) {
|
BatchDataset::BatchDataset(const std::shared_ptr<Dataset> &input, int32_t batch_size, bool drop_remainder) {
|
||||||
// Default values
|
// Default values
|
||||||
auto ds = std::make_shared<BatchNode>(input->IRNode(), batch_size, drop_remainder);
|
auto ds = std::make_shared<BatchNode>(input->IRNode(), batch_size, drop_remainder);
|
||||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||||
|
@ -433,9 +422,9 @@ BatchDataset::BatchDataset(std::shared_ptr<Dataset> input, int32_t batch_size, b
|
||||||
#ifndef ENABLE_ANDROID
|
#ifndef ENABLE_ANDROID
|
||||||
// Function to create a BucketBatchByLength dataset
|
// Function to create a BucketBatchByLength dataset
|
||||||
BucketBatchByLengthDataset::BucketBatchByLengthDataset(
|
BucketBatchByLengthDataset::BucketBatchByLengthDataset(
|
||||||
std::shared_ptr<Dataset> input, const std::vector<std::vector<char>> &column_names,
|
const std::shared_ptr<Dataset> &input, const std::vector<std::vector<char>> &column_names,
|
||||||
const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
|
const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
|
||||||
std::function<MSTensorVec(MSTensorVec)> element_length_function,
|
const std::function<MSTensorVec(MSTensorVec)> &element_length_function,
|
||||||
const std::map<std::vector<char>, std::pair<std::vector<int64_t>, MSTensor>> &pad_info, bool pad_to_bucket_boundary,
|
const std::map<std::vector<char>, std::pair<std::vector<int64_t>, MSTensor>> &pad_info, bool pad_to_bucket_boundary,
|
||||||
bool drop_remainder) {
|
bool drop_remainder) {
|
||||||
std::shared_ptr<TensorOp> c_func = nullptr;
|
std::shared_ptr<TensorOp> c_func = nullptr;
|
||||||
|
@ -467,7 +456,7 @@ BucketBatchByLengthDataset::BucketBatchByLengthDataset(
|
||||||
ConcatDataset::ConcatDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) {
|
ConcatDataset::ConcatDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) {
|
||||||
std::vector<std::shared_ptr<DatasetNode>> all_datasets;
|
std::vector<std::shared_ptr<DatasetNode>> all_datasets;
|
||||||
(void)std::transform(datasets.begin(), datasets.end(), std::back_inserter(all_datasets),
|
(void)std::transform(datasets.begin(), datasets.end(), std::back_inserter(all_datasets),
|
||||||
[](std::shared_ptr<Dataset> dataset) -> std::shared_ptr<DatasetNode> {
|
[](const std::shared_ptr<Dataset> &dataset) -> std::shared_ptr<DatasetNode> {
|
||||||
return (dataset != nullptr) ? dataset->IRNode() : nullptr;
|
return (dataset != nullptr) ? dataset->IRNode() : nullptr;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -476,7 +465,8 @@ ConcatDataset::ConcatDataset(const std::vector<std::shared_ptr<Dataset>> &datase
|
||||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||||
}
|
}
|
||||||
|
|
||||||
FilterDataset::FilterDataset(std::shared_ptr<Dataset> input, std::function<MSTensorVec(MSTensorVec)> predicate,
|
FilterDataset::FilterDataset(const std::shared_ptr<Dataset> &input,
|
||||||
|
const std::function<MSTensorVec(MSTensorVec)> &predicate,
|
||||||
const std::vector<std::vector<char>> &input_columns) {
|
const std::vector<std::vector<char>> &input_columns) {
|
||||||
std::shared_ptr<TensorOp> c_func = nullptr;
|
std::shared_ptr<TensorOp> c_func = nullptr;
|
||||||
if (predicate) {
|
if (predicate) {
|
||||||
|
@ -488,11 +478,13 @@ FilterDataset::FilterDataset(std::shared_ptr<Dataset> input, std::function<MSTen
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
MapDataset::MapDataset(std::shared_ptr<Dataset> input, std::vector<std::shared_ptr<TensorOperation>> operations,
|
MapDataset::MapDataset(const std::shared_ptr<Dataset> &input,
|
||||||
|
const std::vector<std::shared_ptr<TensorOperation>> &operations,
|
||||||
const std::vector<std::vector<char>> &input_columns,
|
const std::vector<std::vector<char>> &input_columns,
|
||||||
const std::vector<std::vector<char>> &output_columns,
|
const std::vector<std::vector<char>> &output_columns,
|
||||||
const std::vector<std::vector<char>> &project_columns,
|
const std::vector<std::vector<char>> &project_columns,
|
||||||
const std::shared_ptr<DatasetCache> &cache, std::vector<std::shared_ptr<DSCallback>> callbacks) {
|
const std::shared_ptr<DatasetCache> &cache,
|
||||||
|
const std::vector<std::shared_ptr<DSCallback>> &callbacks) {
|
||||||
auto ds = std::make_shared<MapNode>(input->IRNode(), operations, VectorCharToString(input_columns),
|
auto ds = std::make_shared<MapNode>(input->IRNode(), operations, VectorCharToString(input_columns),
|
||||||
VectorCharToString(output_columns), VectorCharToString(project_columns), cache,
|
VectorCharToString(output_columns), VectorCharToString(project_columns), cache,
|
||||||
callbacks);
|
callbacks);
|
||||||
|
@ -500,13 +492,14 @@ MapDataset::MapDataset(std::shared_ptr<Dataset> input, std::vector<std::shared_p
|
||||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||||
}
|
}
|
||||||
|
|
||||||
ProjectDataset::ProjectDataset(std::shared_ptr<Dataset> input, const std::vector<std::vector<char>> &columns) {
|
ProjectDataset::ProjectDataset(const std::shared_ptr<Dataset> &input, const std::vector<std::vector<char>> &columns) {
|
||||||
auto ds = std::make_shared<ProjectNode>(input->IRNode(), VectorCharToString(columns));
|
auto ds = std::make_shared<ProjectNode>(input->IRNode(), VectorCharToString(columns));
|
||||||
|
|
||||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifndef ENABLE_ANDROID
|
#ifndef ENABLE_ANDROID
|
||||||
RenameDataset::RenameDataset(std::shared_ptr<Dataset> input, const std::vector<std::vector<char>> &input_columns,
|
RenameDataset::RenameDataset(const std::shared_ptr<Dataset> &input, const std::vector<std::vector<char>> &input_columns,
|
||||||
const std::vector<std::vector<char>> &output_columns) {
|
const std::vector<std::vector<char>> &output_columns) {
|
||||||
auto ds = std::make_shared<RenameNode>(input->IRNode(), VectorCharToString(input_columns),
|
auto ds = std::make_shared<RenameNode>(input->IRNode(), VectorCharToString(input_columns),
|
||||||
VectorCharToString(output_columns));
|
VectorCharToString(output_columns));
|
||||||
|
@ -515,13 +508,13 @@ RenameDataset::RenameDataset(std::shared_ptr<Dataset> input, const std::vector<s
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
RepeatDataset::RepeatDataset(std::shared_ptr<Dataset> input, int32_t count) {
|
RepeatDataset::RepeatDataset(const std::shared_ptr<Dataset> &input, int32_t count) {
|
||||||
auto ds = std::make_shared<RepeatNode>(input->IRNode(), count);
|
auto ds = std::make_shared<RepeatNode>(input->IRNode(), count);
|
||||||
|
|
||||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||||
}
|
}
|
||||||
|
|
||||||
ShuffleDataset::ShuffleDataset(std::shared_ptr<Dataset> input, int32_t buffer_size) {
|
ShuffleDataset::ShuffleDataset(const std::shared_ptr<Dataset> &input, int32_t buffer_size) {
|
||||||
// Pass in reshuffle_each_epoch with true
|
// Pass in reshuffle_each_epoch with true
|
||||||
auto ds = std::make_shared<ShuffleNode>(input->IRNode(), buffer_size, true);
|
auto ds = std::make_shared<ShuffleNode>(input->IRNode(), buffer_size, true);
|
||||||
|
|
||||||
|
@ -529,13 +522,13 @@ ShuffleDataset::ShuffleDataset(std::shared_ptr<Dataset> input, int32_t buffer_si
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifndef ENABLE_ANDROID
|
#ifndef ENABLE_ANDROID
|
||||||
SkipDataset::SkipDataset(std::shared_ptr<Dataset> input, int32_t count) {
|
SkipDataset::SkipDataset(const std::shared_ptr<Dataset> &input, int32_t count) {
|
||||||
auto ds = std::make_shared<SkipNode>(input->IRNode(), count);
|
auto ds = std::make_shared<SkipNode>(input->IRNode(), count);
|
||||||
|
|
||||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||||
}
|
}
|
||||||
|
|
||||||
TakeDataset::TakeDataset(std::shared_ptr<Dataset> input, int32_t count) {
|
TakeDataset::TakeDataset(const std::shared_ptr<Dataset> &input, int32_t count) {
|
||||||
auto ds = std::make_shared<TakeNode>(input->IRNode(), count);
|
auto ds = std::make_shared<TakeNode>(input->IRNode(), count);
|
||||||
|
|
||||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||||
|
@ -544,7 +537,7 @@ TakeDataset::TakeDataset(std::shared_ptr<Dataset> input, int32_t count) {
|
||||||
ZipDataset::ZipDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) {
|
ZipDataset::ZipDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) {
|
||||||
std::vector<std::shared_ptr<DatasetNode>> all_datasets;
|
std::vector<std::shared_ptr<DatasetNode>> all_datasets;
|
||||||
(void)std::transform(datasets.begin(), datasets.end(), std::back_inserter(all_datasets),
|
(void)std::transform(datasets.begin(), datasets.end(), std::back_inserter(all_datasets),
|
||||||
[](std::shared_ptr<Dataset> dataset) -> std::shared_ptr<DatasetNode> {
|
[](const std::shared_ptr<Dataset> &dataset) -> std::shared_ptr<DatasetNode> {
|
||||||
return (dataset != nullptr) ? dataset->IRNode() : nullptr;
|
return (dataset != nullptr) ? dataset->IRNode() : nullptr;
|
||||||
});
|
});
|
||||||
auto ds = std::make_shared<ZipNode>(all_datasets);
|
auto ds = std::make_shared<ZipNode>(all_datasets);
|
||||||
|
@ -552,6 +545,7 @@ ZipDataset::ZipDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) {
|
||||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
int64_t Dataset::GetBatchSize() {
|
int64_t Dataset::GetBatchSize() {
|
||||||
int64_t batch_size = -1;
|
int64_t batch_size = -1;
|
||||||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||||
|
@ -737,10 +731,11 @@ Status SchemaObj::add_column_char(const std::vector<char> &name, const std::vect
|
||||||
}
|
}
|
||||||
|
|
||||||
Status SchemaObj::schema_to_json(nlohmann::json *out_json) {
|
Status SchemaObj::schema_to_json(nlohmann::json *out_json) {
|
||||||
|
RETURN_UNEXPECTED_IF_NULL(out_json);
|
||||||
nlohmann::json json_file;
|
nlohmann::json json_file;
|
||||||
json_file["columns"] = data_->columns_;
|
json_file["columns"] = data_->columns_;
|
||||||
std::string str_dataset_type_(data_->dataset_type_);
|
std::string str_dataset_type_(data_->dataset_type_);
|
||||||
if (str_dataset_type_ != "") {
|
if (!str_dataset_type_.empty()) {
|
||||||
json_file["datasetType"] = str_dataset_type_;
|
json_file["datasetType"] = str_dataset_type_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -751,13 +746,13 @@ Status SchemaObj::schema_to_json(nlohmann::json *out_json) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<char> SchemaObj::to_json_char() {
|
std::vector<char> SchemaObj::to_json_char() {
|
||||||
nlohmann::json json_file;
|
nlohmann::json json_file;
|
||||||
this->schema_to_json(&json_file);
|
this->schema_to_json(&json_file);
|
||||||
return StringToChar(json_file.dump(2));
|
return StringToChar(json_file.dump(2));
|
||||||
}
|
}
|
||||||
|
|
||||||
void SchemaObj::set_dataset_type(std::string dataset_type) { data_->dataset_type_ = dataset_type.data(); }
|
void SchemaObj::set_dataset_type(const std::string &dataset_type) { data_->dataset_type_ = dataset_type; }
|
||||||
|
|
||||||
void SchemaObj::set_num_rows(int32_t num_rows) { data_->num_rows_ = num_rows; }
|
void SchemaObj::set_num_rows(int32_t num_rows) { data_->num_rows_ = num_rows; }
|
||||||
|
|
||||||
|
@ -817,7 +812,7 @@ Status SchemaObj::from_json(nlohmann::json json_obj) {
|
||||||
for (const auto &it_child : json_obj.items()) {
|
for (const auto &it_child : json_obj.items()) {
|
||||||
if (it_child.key() == "datasetType") {
|
if (it_child.key() == "datasetType") {
|
||||||
std::string str_dataset_type_ = it_child.value();
|
std::string str_dataset_type_ = it_child.value();
|
||||||
data_->dataset_type_ = str_dataset_type_.data();
|
data_->dataset_type_ = str_dataset_type_;
|
||||||
} else if (it_child.key() == "numRows") {
|
} else if (it_child.key() == "numRows") {
|
||||||
data_->num_rows_ = it_child.value();
|
data_->num_rows_ = it_child.value();
|
||||||
} else if (it_child.key() == "columns") {
|
} else if (it_child.key() == "columns") {
|
||||||
|
@ -867,10 +862,10 @@ Status SchemaObj::ParseColumnStringCharIF(const std::vector<char> &json_string)
|
||||||
#ifndef ENABLE_ANDROID
|
#ifndef ENABLE_ANDROID
|
||||||
|
|
||||||
std::shared_ptr<DatasetCache> CreateDatasetCacheCharIF(session_id_type id, uint64_t mem_sz, bool spill,
|
std::shared_ptr<DatasetCache> CreateDatasetCacheCharIF(session_id_type id, uint64_t mem_sz, bool spill,
|
||||||
std::optional<std::vector<char>> hostname,
|
const std::optional<std::vector<char>> &hostname,
|
||||||
std::optional<int32_t> port,
|
const std::optional<int32_t> &port,
|
||||||
std::optional<int32_t> num_connections,
|
const std::optional<int32_t> &num_connections,
|
||||||
std::optional<int32_t> prefetch_sz) {
|
const std::optional<int32_t> &prefetch_sz) {
|
||||||
auto cache = std::make_shared<DatasetCacheImpl>(id, mem_sz, spill, hostname, port, num_connections, prefetch_sz);
|
auto cache = std::make_shared<DatasetCacheImpl>(id, mem_sz, spill, hostname, port, num_connections, prefetch_sz);
|
||||||
return cache;
|
return cache;
|
||||||
}
|
}
|
||||||
|
@ -901,9 +896,10 @@ AlbumDataset::AlbumDataset(const std::vector<char> &dataset_dir, const std::vect
|
||||||
VectorCharToString(column_names), decode, sampler_obj, cache);
|
VectorCharToString(column_names), decode, sampler_obj, cache);
|
||||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||||
}
|
}
|
||||||
|
|
||||||
AlbumDataset::AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
|
AlbumDataset::AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
|
||||||
const std::vector<std::vector<char>> &column_names, bool decode,
|
const std::vector<std::vector<char>> &column_names, bool decode,
|
||||||
const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache) {
|
const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
|
||||||
auto sampler_obj = sampler.get().Parse();
|
auto sampler_obj = sampler.get().Parse();
|
||||||
auto ds = std::make_shared<AlbumNode>(CharToString(dataset_dir), CharToString(data_schema),
|
auto ds = std::make_shared<AlbumNode>(CharToString(dataset_dir), CharToString(data_schema),
|
||||||
VectorCharToString(column_names), decode, sampler_obj, cache);
|
VectorCharToString(column_names), decode, sampler_obj, cache);
|
||||||
|
@ -935,7 +931,7 @@ Caltech256Dataset::Caltech256Dataset(const std::vector<char> &dataset_dir, bool
|
||||||
}
|
}
|
||||||
|
|
||||||
Caltech256Dataset::Caltech256Dataset(const std::vector<char> &dataset_dir, bool decode,
|
Caltech256Dataset::Caltech256Dataset(const std::vector<char> &dataset_dir, bool decode,
|
||||||
const std::reference_wrapper<Sampler> sampler,
|
const std::reference_wrapper<Sampler> &sampler,
|
||||||
const std::shared_ptr<DatasetCache> &cache) {
|
const std::shared_ptr<DatasetCache> &cache) {
|
||||||
auto sampler_obj = sampler.get().Parse();
|
auto sampler_obj = sampler.get().Parse();
|
||||||
auto ds = std::make_shared<Caltech256Node>(CharToString(dataset_dir), decode, sampler_obj, cache);
|
auto ds = std::make_shared<Caltech256Node>(CharToString(dataset_dir), decode, sampler_obj, cache);
|
||||||
|
@ -951,6 +947,7 @@ CelebADataset::CelebADataset(const std::vector<char> &dataset_dir, const std::ve
|
||||||
SetCharToString(extensions), cache);
|
SetCharToString(extensions), cache);
|
||||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||||
}
|
}
|
||||||
|
|
||||||
CelebADataset::CelebADataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
CelebADataset::CelebADataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||||
const Sampler *sampler, bool decode, const std::set<std::vector<char>> &extensions,
|
const Sampler *sampler, bool decode, const std::set<std::vector<char>> &extensions,
|
||||||
const std::shared_ptr<DatasetCache> &cache) {
|
const std::shared_ptr<DatasetCache> &cache) {
|
||||||
|
@ -959,8 +956,9 @@ CelebADataset::CelebADataset(const std::vector<char> &dataset_dir, const std::ve
|
||||||
SetCharToString(extensions), cache);
|
SetCharToString(extensions), cache);
|
||||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||||
}
|
}
|
||||||
|
|
||||||
CelebADataset::CelebADataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
CelebADataset::CelebADataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||||
const std::reference_wrapper<Sampler> sampler, bool decode,
|
const std::reference_wrapper<Sampler> &sampler, bool decode,
|
||||||
const std::set<std::vector<char>> &extensions,
|
const std::set<std::vector<char>> &extensions,
|
||||||
const std::shared_ptr<DatasetCache> &cache) {
|
const std::shared_ptr<DatasetCache> &cache) {
|
||||||
auto sampler_obj = sampler.get().Parse();
|
auto sampler_obj = sampler.get().Parse();
|
||||||
|
@ -975,14 +973,16 @@ Cifar10Dataset::Cifar10Dataset(const std::vector<char> &dataset_dir, const std::
|
||||||
auto ds = std::make_shared<Cifar10Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
auto ds = std::make_shared<Cifar10Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
||||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||||
}
|
}
|
||||||
|
|
||||||
Cifar10Dataset::Cifar10Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
Cifar10Dataset::Cifar10Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||||
const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
|
const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
|
||||||
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
|
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
|
||||||
auto ds = std::make_shared<Cifar10Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
auto ds = std::make_shared<Cifar10Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
||||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||||
}
|
}
|
||||||
|
|
||||||
Cifar10Dataset::Cifar10Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
Cifar10Dataset::Cifar10Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||||
const std::reference_wrapper<Sampler> sampler,
|
const std::reference_wrapper<Sampler> &sampler,
|
||||||
const std::shared_ptr<DatasetCache> &cache) {
|
const std::shared_ptr<DatasetCache> &cache) {
|
||||||
auto sampler_obj = sampler.get().Parse();
|
auto sampler_obj = sampler.get().Parse();
|
||||||
auto ds = std::make_shared<Cifar10Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
auto ds = std::make_shared<Cifar10Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
||||||
|
@ -995,14 +995,16 @@ Cifar100Dataset::Cifar100Dataset(const std::vector<char> &dataset_dir, const std
|
||||||
auto ds = std::make_shared<Cifar100Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
auto ds = std::make_shared<Cifar100Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
||||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||||
}
|
}
|
||||||
|
|
||||||
Cifar100Dataset::Cifar100Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
Cifar100Dataset::Cifar100Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||||
const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
|
const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
|
||||||
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
|
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
|
||||||
auto ds = std::make_shared<Cifar100Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
auto ds = std::make_shared<Cifar100Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
||||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||||
}
|
}
|
||||||
|
|
||||||
Cifar100Dataset::Cifar100Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
Cifar100Dataset::Cifar100Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||||
const std::reference_wrapper<Sampler> sampler,
|
const std::reference_wrapper<Sampler> &sampler,
|
||||||
const std::shared_ptr<DatasetCache> &cache) {
|
const std::shared_ptr<DatasetCache> &cache) {
|
||||||
auto sampler_obj = sampler.get().Parse();
|
auto sampler_obj = sampler.get().Parse();
|
||||||
auto ds = std::make_shared<Cifar100Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
auto ds = std::make_shared<Cifar100Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
||||||
|
@ -1030,7 +1032,7 @@ CityscapesDataset::CityscapesDataset(const std::vector<char> &dataset_dir, const
|
||||||
|
|
||||||
CityscapesDataset::CityscapesDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
CityscapesDataset::CityscapesDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||||
const std::vector<char> &quality_mode, const std::vector<char> &task, bool decode,
|
const std::vector<char> &quality_mode, const std::vector<char> &task, bool decode,
|
||||||
const std::reference_wrapper<Sampler> sampler,
|
const std::reference_wrapper<Sampler> &sampler,
|
||||||
const std::shared_ptr<DatasetCache> &cache) {
|
const std::shared_ptr<DatasetCache> &cache) {
|
||||||
auto sampler_obj = sampler.get().Parse();
|
auto sampler_obj = sampler.get().Parse();
|
||||||
auto ds = std::make_shared<CityscapesNode>(CharToString(dataset_dir), CharToString(usage), CharToString(quality_mode),
|
auto ds = std::make_shared<CityscapesNode>(CharToString(dataset_dir), CharToString(usage), CharToString(quality_mode),
|
||||||
|
@ -1054,6 +1056,7 @@ CocoDataset::CocoDataset(const std::vector<char> &dataset_dir, const std::vector
|
||||||
decode, sampler_obj, cache, extra_metadata);
|
decode, sampler_obj, cache, extra_metadata);
|
||||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||||
}
|
}
|
||||||
|
|
||||||
CocoDataset::CocoDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
|
CocoDataset::CocoDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
|
||||||
const std::vector<char> &task, const bool &decode, const Sampler *sampler,
|
const std::vector<char> &task, const bool &decode, const Sampler *sampler,
|
||||||
const std::shared_ptr<DatasetCache> &cache, const bool &extra_metadata) {
|
const std::shared_ptr<DatasetCache> &cache, const bool &extra_metadata) {
|
||||||
|
@ -1062,9 +1065,10 @@ CocoDataset::CocoDataset(const std::vector<char> &dataset_dir, const std::vector
|
||||||
decode, sampler_obj, cache, extra_metadata);
|
decode, sampler_obj, cache, extra_metadata);
|
||||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||||
}
|
}
|
||||||
|
|
||||||
CocoDataset::CocoDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
|
CocoDataset::CocoDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
|
||||||
const std::vector<char> &task, const bool &decode,
|
const std::vector<char> &task, const bool &decode,
|
||||||
const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache,
|
const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache,
|
||||||
const bool &extra_metadata) {
|
const bool &extra_metadata) {
|
||||||
auto sampler_obj = sampler.get().Parse();
|
auto sampler_obj = sampler.get().Parse();
|
||||||
auto ds = std::make_shared<CocoNode>(CharToString(dataset_dir), CharToString(annotation_file), CharToString(task),
|
auto ds = std::make_shared<CocoNode>(CharToString(dataset_dir), CharToString(annotation_file), CharToString(task),
|
||||||
|
@ -1118,7 +1122,7 @@ DIV2KDataset::DIV2KDataset(const std::vector<char> &dataset_dir, const std::vect
|
||||||
|
|
||||||
DIV2KDataset::DIV2KDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
DIV2KDataset::DIV2KDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||||
const std::vector<char> &downgrade, int32_t scale, bool decode,
|
const std::vector<char> &downgrade, int32_t scale, bool decode,
|
||||||
const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache) {
|
const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
|
||||||
auto sampler_obj = sampler.get().Parse();
|
auto sampler_obj = sampler.get().Parse();
|
||||||
auto ds = std::make_shared<DIV2KNode>(CharToString(dataset_dir), CharToString(usage), CharToString(downgrade), scale,
|
auto ds = std::make_shared<DIV2KNode>(CharToString(dataset_dir), CharToString(usage), CharToString(downgrade), scale,
|
||||||
decode, sampler_obj, cache);
|
decode, sampler_obj, cache);
|
||||||
|
@ -1144,7 +1148,7 @@ EMnistDataset::EMnistDataset(const std::vector<char> &dataset_dir, const std::ve
|
||||||
}
|
}
|
||||||
|
|
||||||
EMnistDataset::EMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &name,
|
EMnistDataset::EMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &name,
|
||||||
const std::vector<char> &usage, const std::reference_wrapper<Sampler> sampler,
|
const std::vector<char> &usage, const std::reference_wrapper<Sampler> &sampler,
|
||||||
const std::shared_ptr<DatasetCache> &cache) {
|
const std::shared_ptr<DatasetCache> &cache) {
|
||||||
auto sampler_obj = sampler.get().Parse();
|
auto sampler_obj = sampler.get().Parse();
|
||||||
auto ds = std::make_shared<EMnistNode>(CharToString(dataset_dir), CharToString(name), CharToString(usage),
|
auto ds = std::make_shared<EMnistNode>(CharToString(dataset_dir), CharToString(name), CharToString(usage),
|
||||||
|
@ -1175,7 +1179,7 @@ FakeImageDataset::FakeImageDataset(int32_t num_images, const std::vector<int32_t
|
||||||
}
|
}
|
||||||
|
|
||||||
FakeImageDataset::FakeImageDataset(int32_t num_images, const std::vector<int32_t> &image_size, int32_t num_classes,
|
FakeImageDataset::FakeImageDataset(int32_t num_images, const std::vector<int32_t> &image_size, int32_t num_classes,
|
||||||
int32_t base_seed, const std::reference_wrapper<Sampler> sampler,
|
int32_t base_seed, const std::reference_wrapper<Sampler> &sampler,
|
||||||
const std::shared_ptr<DatasetCache> &cache) {
|
const std::shared_ptr<DatasetCache> &cache) {
|
||||||
auto sampler_obj = sampler.get().Parse();
|
auto sampler_obj = sampler.get().Parse();
|
||||||
auto ds = std::make_shared<FakeImageNode>(num_images, image_size, num_classes, base_seed, sampler_obj, cache);
|
auto ds = std::make_shared<FakeImageNode>(num_images, image_size, num_classes, base_seed, sampler_obj, cache);
|
||||||
|
@ -1198,7 +1202,7 @@ FashionMnistDataset::FashionMnistDataset(const std::vector<char> &dataset_dir, c
|
||||||
}
|
}
|
||||||
|
|
||||||
FashionMnistDataset::FashionMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
FashionMnistDataset::FashionMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||||
const std::reference_wrapper<Sampler> sampler,
|
const std::reference_wrapper<Sampler> &sampler,
|
||||||
const std::shared_ptr<DatasetCache> &cache) {
|
const std::shared_ptr<DatasetCache> &cache) {
|
||||||
auto sampler_obj = sampler.get().Parse();
|
auto sampler_obj = sampler.get().Parse();
|
||||||
auto ds = std::make_shared<FashionMnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
auto ds = std::make_shared<FashionMnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
||||||
|
@ -1223,7 +1227,7 @@ FlickrDataset::FlickrDataset(const std::vector<char> &dataset_dir, const std::ve
|
||||||
}
|
}
|
||||||
|
|
||||||
FlickrDataset::FlickrDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
|
FlickrDataset::FlickrDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
|
||||||
bool decode, const std::reference_wrapper<Sampler> sampler,
|
bool decode, const std::reference_wrapper<Sampler> &sampler,
|
||||||
const std::shared_ptr<DatasetCache> &cache) {
|
const std::shared_ptr<DatasetCache> &cache) {
|
||||||
auto sampler_obj = sampler.get().Parse();
|
auto sampler_obj = sampler.get().Parse();
|
||||||
auto ds =
|
auto ds =
|
||||||
|
@ -1261,7 +1265,7 @@ ImageFolderDataset::ImageFolderDataset(const std::vector<char> &dataset_dir, boo
|
||||||
}
|
}
|
||||||
|
|
||||||
ImageFolderDataset::ImageFolderDataset(const std::vector<char> &dataset_dir, bool decode,
|
ImageFolderDataset::ImageFolderDataset(const std::vector<char> &dataset_dir, bool decode,
|
||||||
const std::reference_wrapper<Sampler> sampler,
|
const std::reference_wrapper<Sampler> &sampler,
|
||||||
const std::set<std::vector<char>> &extensions,
|
const std::set<std::vector<char>> &extensions,
|
||||||
const std::map<std::vector<char>, int32_t> &class_indexing,
|
const std::map<std::vector<char>, int32_t> &class_indexing,
|
||||||
const std::shared_ptr<DatasetCache> &cache) {
|
const std::shared_ptr<DatasetCache> &cache) {
|
||||||
|
@ -1292,7 +1296,7 @@ IMDBDataset::IMDBDataset(const std::vector<char> &dataset_dir, const std::vector
|
||||||
}
|
}
|
||||||
|
|
||||||
IMDBDataset::IMDBDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
IMDBDataset::IMDBDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||||
const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache) {
|
const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
|
||||||
// Create logical representation of IMDBDataset.
|
// Create logical representation of IMDBDataset.
|
||||||
auto sampler_obj = sampler.get().Parse();
|
auto sampler_obj = sampler.get().Parse();
|
||||||
auto ds = std::make_shared<IMDBNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
auto ds = std::make_shared<IMDBNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
||||||
|
@ -1335,7 +1339,7 @@ KMnistDataset::KMnistDataset(const std::vector<char> &dataset_dir, const std::ve
|
||||||
}
|
}
|
||||||
|
|
||||||
KMnistDataset::KMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
KMnistDataset::KMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||||
const std::reference_wrapper<Sampler> sampler,
|
const std::reference_wrapper<Sampler> &sampler,
|
||||||
const std::shared_ptr<DatasetCache> &cache) {
|
const std::shared_ptr<DatasetCache> &cache) {
|
||||||
auto sampler_obj = sampler.get().Parse();
|
auto sampler_obj = sampler.get().Parse();
|
||||||
auto ds = std::make_shared<KMnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
auto ds = std::make_shared<KMnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
||||||
|
@ -1356,7 +1360,7 @@ LJSpeechDataset::LJSpeechDataset(const std::vector<char> &dataset_dir, const Sam
|
||||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||||
}
|
}
|
||||||
|
|
||||||
LJSpeechDataset::LJSpeechDataset(const std::vector<char> &dataset_dir, const std::reference_wrapper<Sampler> sampler,
|
LJSpeechDataset::LJSpeechDataset(const std::vector<char> &dataset_dir, const std::reference_wrapper<Sampler> &sampler,
|
||||||
const std::shared_ptr<DatasetCache> &cache) {
|
const std::shared_ptr<DatasetCache> &cache) {
|
||||||
auto sampler_obj = sampler.get().Parse();
|
auto sampler_obj = sampler.get().Parse();
|
||||||
auto ds = std::make_shared<LJSpeechNode>(CharToString(dataset_dir), sampler_obj, cache);
|
auto ds = std::make_shared<LJSpeechNode>(CharToString(dataset_dir), sampler_obj, cache);
|
||||||
|
@ -1372,6 +1376,7 @@ ManifestDataset::ManifestDataset(const std::vector<char> &dataset_file, const st
|
||||||
MapCharToString(class_indexing), decode, cache);
|
MapCharToString(class_indexing), decode, cache);
|
||||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||||
}
|
}
|
||||||
|
|
||||||
ManifestDataset::ManifestDataset(const std::vector<char> &dataset_file, const std::vector<char> &usage,
|
ManifestDataset::ManifestDataset(const std::vector<char> &dataset_file, const std::vector<char> &usage,
|
||||||
const Sampler *sampler, const std::map<std::vector<char>, int32_t> &class_indexing,
|
const Sampler *sampler, const std::map<std::vector<char>, int32_t> &class_indexing,
|
||||||
bool decode, const std::shared_ptr<DatasetCache> &cache) {
|
bool decode, const std::shared_ptr<DatasetCache> &cache) {
|
||||||
|
@ -1380,8 +1385,9 @@ ManifestDataset::ManifestDataset(const std::vector<char> &dataset_file, const st
|
||||||
MapCharToString(class_indexing), decode, cache);
|
MapCharToString(class_indexing), decode, cache);
|
||||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||||
}
|
}
|
||||||
|
|
||||||
ManifestDataset::ManifestDataset(const std::vector<char> &dataset_file, const std::vector<char> &usage,
|
ManifestDataset::ManifestDataset(const std::vector<char> &dataset_file, const std::vector<char> &usage,
|
||||||
const std::reference_wrapper<Sampler> sampler,
|
const std::reference_wrapper<Sampler> &sampler,
|
||||||
const std::map<std::vector<char>, int32_t> &class_indexing, bool decode,
|
const std::map<std::vector<char>, int32_t> &class_indexing, bool decode,
|
||||||
const std::shared_ptr<DatasetCache> &cache) {
|
const std::shared_ptr<DatasetCache> &cache) {
|
||||||
auto sampler_obj = sampler.get().Parse();
|
auto sampler_obj = sampler.get().Parse();
|
||||||
|
@ -1421,7 +1427,7 @@ MindDataDataset::MindDataDataset(const std::vector<char> &dataset_file,
|
||||||
|
|
||||||
MindDataDataset::MindDataDataset(const std::vector<char> &dataset_file,
|
MindDataDataset::MindDataDataset(const std::vector<char> &dataset_file,
|
||||||
const std::vector<std::vector<char>> &columns_list,
|
const std::vector<std::vector<char>> &columns_list,
|
||||||
const std::reference_wrapper<Sampler> sampler, const nlohmann::json *padded_sample,
|
const std::reference_wrapper<Sampler> &sampler, const nlohmann::json *padded_sample,
|
||||||
int64_t num_padded, ShuffleMode shuffle_mode,
|
int64_t num_padded, ShuffleMode shuffle_mode,
|
||||||
const std::shared_ptr<DatasetCache> &cache) {
|
const std::shared_ptr<DatasetCache> &cache) {
|
||||||
auto sampler_obj = sampler.get().Parse();
|
auto sampler_obj = sampler.get().Parse();
|
||||||
|
@ -1468,7 +1474,7 @@ MindDataDataset::MindDataDataset(const std::vector<std::vector<char>> &dataset_f
|
||||||
|
|
||||||
MindDataDataset::MindDataDataset(const std::vector<std::vector<char>> &dataset_files,
|
MindDataDataset::MindDataDataset(const std::vector<std::vector<char>> &dataset_files,
|
||||||
const std::vector<std::vector<char>> &columns_list,
|
const std::vector<std::vector<char>> &columns_list,
|
||||||
const std::reference_wrapper<Sampler> sampler, const nlohmann::json *padded_sample,
|
const std::reference_wrapper<Sampler> &sampler, const nlohmann::json *padded_sample,
|
||||||
int64_t num_padded, ShuffleMode shuffle_mode,
|
int64_t num_padded, ShuffleMode shuffle_mode,
|
||||||
const std::shared_ptr<DatasetCache> &cache) {
|
const std::shared_ptr<DatasetCache> &cache) {
|
||||||
auto sampler_obj = sampler.get().Parse();
|
auto sampler_obj = sampler.get().Parse();
|
||||||
|
@ -1488,14 +1494,16 @@ MnistDataset::MnistDataset(const std::vector<char> &dataset_dir, const std::vect
|
||||||
auto ds = std::make_shared<MnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
auto ds = std::make_shared<MnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
||||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||||
}
|
}
|
||||||
|
|
||||||
MnistDataset::MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const Sampler *sampler,
|
MnistDataset::MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const Sampler *sampler,
|
||||||
const std::shared_ptr<DatasetCache> &cache) {
|
const std::shared_ptr<DatasetCache> &cache) {
|
||||||
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
|
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
|
||||||
auto ds = std::make_shared<MnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
auto ds = std::make_shared<MnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
||||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||||
}
|
}
|
||||||
|
|
||||||
MnistDataset::MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
MnistDataset::MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||||
const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache) {
|
const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
|
||||||
auto sampler_obj = sampler.get().Parse();
|
auto sampler_obj = sampler.get().Parse();
|
||||||
auto ds = std::make_shared<MnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
auto ds = std::make_shared<MnistNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
||||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||||
|
@ -1529,7 +1537,7 @@ PhotoTourDataset::PhotoTourDataset(const std::vector<char> &dataset_dir, const s
|
||||||
}
|
}
|
||||||
|
|
||||||
PhotoTourDataset::PhotoTourDataset(const std::vector<char> &dataset_dir, const std::vector<char> &name,
|
PhotoTourDataset::PhotoTourDataset(const std::vector<char> &dataset_dir, const std::vector<char> &name,
|
||||||
const std::vector<char> &usage, const std::reference_wrapper<Sampler> sampler,
|
const std::vector<char> &usage, const std::reference_wrapper<Sampler> &sampler,
|
||||||
const std::shared_ptr<DatasetCache> &cache) {
|
const std::shared_ptr<DatasetCache> &cache) {
|
||||||
auto sampler_obj = sampler.get().Parse();
|
auto sampler_obj = sampler.get().Parse();
|
||||||
auto ds = std::make_shared<PhotoTourNode>(CharToString(dataset_dir), CharToString(name), CharToString(usage),
|
auto ds = std::make_shared<PhotoTourNode>(CharToString(dataset_dir), CharToString(name), CharToString(usage),
|
||||||
|
@ -1556,7 +1564,7 @@ Places365Dataset::Places365Dataset(const std::vector<char> &dataset_dir, const s
|
||||||
}
|
}
|
||||||
|
|
||||||
Places365Dataset::Places365Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
Places365Dataset::Places365Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||||
const bool small, const bool decode, const std::reference_wrapper<Sampler> sampler,
|
const bool small, const bool decode, const std::reference_wrapper<Sampler> &sampler,
|
||||||
const std::shared_ptr<DatasetCache> &cache) {
|
const std::shared_ptr<DatasetCache> &cache) {
|
||||||
auto sampler_obj = sampler.get().Parse();
|
auto sampler_obj = sampler.get().Parse();
|
||||||
auto ds =
|
auto ds =
|
||||||
|
@ -1579,7 +1587,7 @@ QMnistDataset::QMnistDataset(const std::vector<char> &dataset_dir, const std::ve
|
||||||
}
|
}
|
||||||
|
|
||||||
QMnistDataset::QMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool compat,
|
QMnistDataset::QMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool compat,
|
||||||
const std::reference_wrapper<Sampler> sampler,
|
const std::reference_wrapper<Sampler> &sampler,
|
||||||
const std::shared_ptr<DatasetCache> &cache) {
|
const std::shared_ptr<DatasetCache> &cache) {
|
||||||
auto sampler_obj = sampler.get().Parse();
|
auto sampler_obj = sampler.get().Parse();
|
||||||
auto ds = std::make_shared<QMnistNode>(CharToString(dataset_dir), CharToString(usage), compat, sampler_obj, cache);
|
auto ds = std::make_shared<QMnistNode>(CharToString(dataset_dir), CharToString(usage), compat, sampler_obj, cache);
|
||||||
|
@ -1650,7 +1658,7 @@ STL10Dataset::STL10Dataset(const std::vector<char> &dataset_dir, const std::vect
|
||||||
}
|
}
|
||||||
|
|
||||||
STL10Dataset::STL10Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
STL10Dataset::STL10Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||||
const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache) {
|
const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
|
||||||
auto sampler_obj = sampler.get().Parse();
|
auto sampler_obj = sampler.get().Parse();
|
||||||
auto ds = std::make_shared<STL10Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
auto ds = std::make_shared<STL10Node>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
||||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||||
|
@ -1681,6 +1689,7 @@ VOCDataset::VOCDataset(const std::vector<char> &dataset_dir, const std::vector<c
|
||||||
MapCharToString(class_indexing), decode, sampler_obj, cache, extra_metadata);
|
MapCharToString(class_indexing), decode, sampler_obj, cache, extra_metadata);
|
||||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||||
}
|
}
|
||||||
|
|
||||||
VOCDataset::VOCDataset(const std::vector<char> &dataset_dir, const std::vector<char> &task,
|
VOCDataset::VOCDataset(const std::vector<char> &dataset_dir, const std::vector<char> &task,
|
||||||
const std::vector<char> &usage, const std::map<std::vector<char>, int32_t> &class_indexing,
|
const std::vector<char> &usage, const std::map<std::vector<char>, int32_t> &class_indexing,
|
||||||
bool decode, const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache,
|
bool decode, const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache,
|
||||||
|
@ -1690,9 +1699,10 @@ VOCDataset::VOCDataset(const std::vector<char> &dataset_dir, const std::vector<c
|
||||||
MapCharToString(class_indexing), decode, sampler_obj, cache, extra_metadata);
|
MapCharToString(class_indexing), decode, sampler_obj, cache, extra_metadata);
|
||||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||||
}
|
}
|
||||||
|
|
||||||
VOCDataset::VOCDataset(const std::vector<char> &dataset_dir, const std::vector<char> &task,
|
VOCDataset::VOCDataset(const std::vector<char> &dataset_dir, const std::vector<char> &task,
|
||||||
const std::vector<char> &usage, const std::map<std::vector<char>, int32_t> &class_indexing,
|
const std::vector<char> &usage, const std::map<std::vector<char>, int32_t> &class_indexing,
|
||||||
bool decode, const std::reference_wrapper<Sampler> sampler,
|
bool decode, const std::reference_wrapper<Sampler> &sampler,
|
||||||
const std::shared_ptr<DatasetCache> &cache, bool extra_metadata) {
|
const std::shared_ptr<DatasetCache> &cache, bool extra_metadata) {
|
||||||
auto sampler_obj = sampler.get().Parse();
|
auto sampler_obj = sampler.get().Parse();
|
||||||
auto ds = std::make_shared<VOCNode>(CharToString(dataset_dir), CharToString(task), CharToString(usage),
|
auto ds = std::make_shared<VOCNode>(CharToString(dataset_dir), CharToString(task), CharToString(usage),
|
||||||
|
@ -1710,13 +1720,14 @@ WikiTextDataset::WikiTextDataset(const std::vector<char> &dataset_dir, const std
|
||||||
|
|
||||||
RandomDataDataset::RandomDataDataset(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema,
|
RandomDataDataset::RandomDataDataset(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema,
|
||||||
const std::vector<std::vector<char>> &columns_list,
|
const std::vector<std::vector<char>> &columns_list,
|
||||||
std::shared_ptr<DatasetCache> cache) {
|
const std::shared_ptr<DatasetCache> &cache) {
|
||||||
auto ds = std::make_shared<RandomNode>(total_rows, std::move(schema), VectorCharToString(columns_list), cache);
|
auto ds = std::make_shared<RandomNode>(total_rows, std::move(schema), VectorCharToString(columns_list), cache);
|
||||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||||
}
|
}
|
||||||
|
|
||||||
RandomDataDataset::RandomDataDataset(const int32_t &total_rows, const std::vector<char> &schema_path,
|
RandomDataDataset::RandomDataDataset(const int32_t &total_rows, const std::vector<char> &schema_path,
|
||||||
const std::vector<std::vector<char>> &columns_list,
|
const std::vector<std::vector<char>> &columns_list,
|
||||||
std::shared_ptr<DatasetCache> cache) {
|
const std::shared_ptr<DatasetCache> &cache) {
|
||||||
auto ds =
|
auto ds =
|
||||||
std::make_shared<RandomNode>(total_rows, CharToString(schema_path), VectorCharToString(columns_list), cache);
|
std::make_shared<RandomNode>(total_rows, CharToString(schema_path), VectorCharToString(columns_list), cache);
|
||||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||||
|
@ -1736,8 +1747,8 @@ SBUDataset::SBUDataset(const std::vector<char> &dataset_dir, bool decode, const
|
||||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||||
}
|
}
|
||||||
|
|
||||||
SBUDataset::SBUDataset(const std::vector<char> &dataset_dir, bool decode, const std::reference_wrapper<Sampler> sampler,
|
SBUDataset::SBUDataset(const std::vector<char> &dataset_dir, bool decode,
|
||||||
const std::shared_ptr<DatasetCache> &cache) {
|
const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
|
||||||
auto sampler_obj = sampler.get().Parse();
|
auto sampler_obj = sampler.get().Parse();
|
||||||
auto ds = std::make_shared<SBUNode>(CharToString(dataset_dir), decode, sampler_obj, cache);
|
auto ds = std::make_shared<SBUNode>(CharToString(dataset_dir), decode, sampler_obj, cache);
|
||||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||||
|
@ -1767,7 +1778,7 @@ SpeechCommandsDataset::SpeechCommandsDataset(const std::vector<char> &dataset_di
|
||||||
}
|
}
|
||||||
|
|
||||||
SpeechCommandsDataset::SpeechCommandsDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
SpeechCommandsDataset::SpeechCommandsDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||||
const std::reference_wrapper<Sampler> sampler,
|
const std::reference_wrapper<Sampler> &sampler,
|
||||||
const std::shared_ptr<DatasetCache> &cache) {
|
const std::shared_ptr<DatasetCache> &cache) {
|
||||||
auto sampler_obj = sampler.get().Parse();
|
auto sampler_obj = sampler.get().Parse();
|
||||||
auto ds = std::make_shared<SpeechCommandsNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
auto ds = std::make_shared<SpeechCommandsNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache);
|
||||||
|
@ -1777,16 +1788,18 @@ SpeechCommandsDataset::SpeechCommandsDataset(const std::vector<char> &dataset_di
|
||||||
TFRecordDataset::TFRecordDataset(const std::vector<std::vector<char>> &dataset_files, const std::vector<char> &schema,
|
TFRecordDataset::TFRecordDataset(const std::vector<std::vector<char>> &dataset_files, const std::vector<char> &schema,
|
||||||
const std::vector<std::vector<char>> &columns_list, int64_t num_samples,
|
const std::vector<std::vector<char>> &columns_list, int64_t num_samples,
|
||||||
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, bool shard_equal_rows,
|
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, bool shard_equal_rows,
|
||||||
std::shared_ptr<DatasetCache> cache) {
|
const std::shared_ptr<DatasetCache> &cache) {
|
||||||
auto ds = std::make_shared<TFRecordNode>(VectorCharToString(dataset_files), CharToString(schema),
|
auto ds = std::make_shared<TFRecordNode>(VectorCharToString(dataset_files), CharToString(schema),
|
||||||
VectorCharToString(columns_list), num_samples, shuffle, num_shards, shard_id,
|
VectorCharToString(columns_list), num_samples, shuffle, num_shards, shard_id,
|
||||||
shard_equal_rows, cache);
|
shard_equal_rows, cache);
|
||||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||||
}
|
}
|
||||||
TFRecordDataset::TFRecordDataset(const std::vector<std::vector<char>> &dataset_files, std::shared_ptr<SchemaObj> schema,
|
|
||||||
|
TFRecordDataset::TFRecordDataset(const std::vector<std::vector<char>> &dataset_files,
|
||||||
|
const std::shared_ptr<SchemaObj> &schema,
|
||||||
const std::vector<std::vector<char>> &columns_list, int64_t num_samples,
|
const std::vector<std::vector<char>> &columns_list, int64_t num_samples,
|
||||||
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, bool shard_equal_rows,
|
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, bool shard_equal_rows,
|
||||||
std::shared_ptr<DatasetCache> cache) {
|
const std::shared_ptr<DatasetCache> &cache) {
|
||||||
auto ds = std::make_shared<TFRecordNode>(VectorCharToString(dataset_files), schema, VectorCharToString(columns_list),
|
auto ds = std::make_shared<TFRecordNode>(VectorCharToString(dataset_files), schema, VectorCharToString(columns_list),
|
||||||
num_samples, shuffle, num_shards, shard_id, shard_equal_rows, cache);
|
num_samples, shuffle, num_shards, shard_id, shard_equal_rows, cache);
|
||||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||||
|
@ -1816,7 +1829,7 @@ WIDERFaceDataset::WIDERFaceDataset(const std::vector<char> &dataset_dir, const s
|
||||||
}
|
}
|
||||||
|
|
||||||
WIDERFaceDataset::WIDERFaceDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
|
WIDERFaceDataset::WIDERFaceDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
|
||||||
const std::reference_wrapper<Sampler> sampler,
|
const std::reference_wrapper<Sampler> &sampler,
|
||||||
const std::shared_ptr<DatasetCache> &cache) {
|
const std::shared_ptr<DatasetCache> &cache) {
|
||||||
auto sampler_obj = sampler.get().Parse();
|
auto sampler_obj = sampler.get().Parse();
|
||||||
auto ds = std::make_shared<WIDERFaceNode>(CharToString(dataset_dir), CharToString(usage), decode, sampler_obj, cache);
|
auto ds = std::make_shared<WIDERFaceNode>(CharToString(dataset_dir), CharToString(usage), decode, sampler_obj, cache);
|
||||||
|
@ -1853,13 +1866,12 @@ YesNoDataset::YesNoDataset(const std::vector<char> &dataset_dir, const Sampler *
|
||||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||||
}
|
}
|
||||||
|
|
||||||
YesNoDataset::YesNoDataset(const std::vector<char> &dataset_dir, const std::reference_wrapper<Sampler> sampler,
|
YesNoDataset::YesNoDataset(const std::vector<char> &dataset_dir, const std::reference_wrapper<Sampler> &sampler,
|
||||||
const std::shared_ptr<DatasetCache> &cache) {
|
const std::shared_ptr<DatasetCache> &cache) {
|
||||||
auto sampler_obj = sampler.get().Parse();
|
auto sampler_obj = sampler.get().Parse();
|
||||||
auto ds = std::make_shared<YesNoNode>(CharToString(dataset_dir), sampler_obj, cache);
|
auto ds = std::make_shared<YesNoNode>(CharToString(dataset_dir), sampler_obj, cache);
|
||||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -38,8 +38,8 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
|
||||||
using json = nlohmann::json;
|
using json = nlohmann::json;
|
||||||
|
|
||||||
struct Execute::ExtraInfo {
|
struct Execute::ExtraInfo {
|
||||||
std::multimap<std::string, std::vector<uint32_t>> aipp_cfg_;
|
std::multimap<std::string, std::vector<uint32_t>> aipp_cfg_;
|
||||||
bool init_with_shared_ptr_ = true; // Initial execute object with shared_ptr as default
|
bool init_with_shared_ptr_ = true; // Initial execute object with shared_ptr as default
|
||||||
|
@ -70,14 +70,14 @@ Status Execute::InitResource(MapTargetDevice device_type, uint32_t device_id) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Execute::Execute(std::shared_ptr<TensorOperation> op, MapTargetDevice device_type, uint32_t device_id) {
|
Execute::Execute(const std::shared_ptr<TensorOperation> &op, MapTargetDevice device_type, uint32_t device_id) {
|
||||||
ops_.emplace_back(std::move(op));
|
ops_.emplace_back(op);
|
||||||
device_type_ = device_type;
|
device_type_ = device_type;
|
||||||
info_ = std::make_shared<ExtraInfo>();
|
info_ = std::make_shared<ExtraInfo>();
|
||||||
(void)InitResource(device_type, device_id);
|
(void)InitResource(device_type, device_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
Execute::Execute(std::shared_ptr<TensorTransform> op, MapTargetDevice device_type, uint32_t device_id) {
|
Execute::Execute(const std::shared_ptr<TensorTransform> &op, MapTargetDevice device_type, uint32_t device_id) {
|
||||||
// Initialize the op and other context
|
// Initialize the op and other context
|
||||||
transforms_.emplace_back(op);
|
transforms_.emplace_back(op);
|
||||||
|
|
||||||
|
@ -86,7 +86,7 @@ Execute::Execute(std::shared_ptr<TensorTransform> op, MapTargetDevice device_typ
|
||||||
(void)InitResource(device_type, device_id);
|
(void)InitResource(device_type, device_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
Execute::Execute(std::reference_wrapper<TensorTransform> op, MapTargetDevice device_type, uint32_t device_id) {
|
Execute::Execute(const std::reference_wrapper<TensorTransform> &op, MapTargetDevice device_type, uint32_t device_id) {
|
||||||
// Initialize the transforms_ and other context
|
// Initialize the transforms_ and other context
|
||||||
std::shared_ptr<TensorOperation> operation = op.get().Parse();
|
std::shared_ptr<TensorOperation> operation = op.get().Parse();
|
||||||
ops_.emplace_back(std::move(operation));
|
ops_.emplace_back(std::move(operation));
|
||||||
|
@ -108,13 +108,15 @@ Execute::Execute(TensorTransform *op, MapTargetDevice device_type, uint32_t devi
|
||||||
(void)InitResource(device_type, device_id);
|
(void)InitResource(device_type, device_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
Execute::Execute(std::vector<std::shared_ptr<TensorOperation>> ops, MapTargetDevice device_type, uint32_t device_id)
|
Execute::Execute(const std::vector<std::shared_ptr<TensorOperation>> &ops, MapTargetDevice device_type,
|
||||||
: ops_(std::move(ops)), device_type_(device_type) {
|
uint32_t device_id)
|
||||||
|
: ops_(ops), device_type_(device_type) {
|
||||||
info_ = std::make_shared<ExtraInfo>();
|
info_ = std::make_shared<ExtraInfo>();
|
||||||
(void)InitResource(device_type, device_id);
|
(void)InitResource(device_type, device_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
Execute::Execute(std::vector<std::shared_ptr<TensorTransform>> ops, MapTargetDevice device_type, uint32_t device_id) {
|
Execute::Execute(const std::vector<std::shared_ptr<TensorTransform>> &ops, MapTargetDevice device_type,
|
||||||
|
uint32_t device_id) {
|
||||||
// Initialize the transforms_ and other context
|
// Initialize the transforms_ and other context
|
||||||
transforms_ = ops;
|
transforms_ = ops;
|
||||||
|
|
||||||
|
@ -123,7 +125,7 @@ Execute::Execute(std::vector<std::shared_ptr<TensorTransform>> ops, MapTargetDev
|
||||||
(void)InitResource(device_type, device_id);
|
(void)InitResource(device_type, device_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
Execute::Execute(const std::vector<std::reference_wrapper<TensorTransform>> ops, MapTargetDevice device_type,
|
Execute::Execute(const std::vector<std::reference_wrapper<TensorTransform>> &ops, MapTargetDevice device_type,
|
||||||
uint32_t device_id) {
|
uint32_t device_id) {
|
||||||
// Initialize the transforms_ and other context
|
// Initialize the transforms_ and other context
|
||||||
if (device_type == MapTargetDevice::kCpu) {
|
if (device_type == MapTargetDevice::kCpu) {
|
||||||
|
@ -174,7 +176,6 @@ Status Execute::operator()(const mindspore::MSTensor &input, mindspore::MSTensor
|
||||||
// Validate input tensor
|
// Validate input tensor
|
||||||
RETURN_UNEXPECTED_IF_NULL(output);
|
RETURN_UNEXPECTED_IF_NULL(output);
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(input.DataSize() > 0, "Input Tensor has no data.");
|
CHECK_FAIL_RETURN_UNEXPECTED(input.DataSize() > 0, "Input Tensor has no data.");
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(output != nullptr, "Output Tensor can not be nullptr.");
|
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(ValidateDevice(), "Device Type should be 'Ascend310' or 'CPU'.");
|
CHECK_FAIL_RETURN_UNEXPECTED(ValidateDevice(), "Device Type should be 'Ascend310' or 'CPU'.");
|
||||||
|
|
||||||
// Parse TensorTransform transforms_ into TensorOperation ops_
|
// Parse TensorTransform transforms_ into TensorOperation ops_
|
||||||
|
@ -269,7 +270,6 @@ Status Execute::operator()(const std::vector<MSTensor> &input_tensor_list, std::
|
||||||
// Validate input tensor
|
// Validate input tensor
|
||||||
RETURN_UNEXPECTED_IF_NULL(output_tensor_list);
|
RETURN_UNEXPECTED_IF_NULL(output_tensor_list);
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(!input_tensor_list.empty(), "Input Tensor is not valid.");
|
CHECK_FAIL_RETURN_UNEXPECTED(!input_tensor_list.empty(), "Input Tensor is not valid.");
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(output_tensor_list != nullptr, "Output Tensor can not be nullptr.");
|
|
||||||
output_tensor_list->clear();
|
output_tensor_list->clear();
|
||||||
for (auto &tensor : input_tensor_list) {
|
for (auto &tensor : input_tensor_list) {
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(tensor.DataSize() > 0, "Input Tensor has no data.");
|
CHECK_FAIL_RETURN_UNEXPECTED(tensor.DataSize() > 0, "Input Tensor has no data.");
|
||||||
|
@ -374,16 +374,16 @@ std::vector<uint32_t> AippSizeFilter(const std::vector<uint32_t> &resize_para, c
|
||||||
std::vector<uint32_t> aipp_size;
|
std::vector<uint32_t> aipp_size;
|
||||||
|
|
||||||
// Special condition where (no Crop and no Resize) or (no Crop and resize with fixed ratio) will lead to dynamic input
|
// Special condition where (no Crop and no Resize) or (no Crop and resize with fixed ratio) will lead to dynamic input
|
||||||
if ((resize_para.size() == 0 || resize_para.size() == 1) && crop_para.size() == 0) {
|
if ((resize_para.empty() || resize_para.size() == 1) && crop_para.empty()) {
|
||||||
aipp_size = {0, 0};
|
aipp_size = {0, 0};
|
||||||
MS_LOG(WARNING) << "Dynamic input shape is not supported, incomplete aipp config file will be generated. Please "
|
MS_LOG(WARNING) << "Dynamic input shape is not supported, incomplete aipp config file will be generated. Please "
|
||||||
"checkout your TensorTransform input, both src_image_size_h and src_image_size will be 0.";
|
"checkout your TensorTransform input, both src_image_size_h and src_image_size will be 0.";
|
||||||
return aipp_size;
|
return aipp_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (resize_para.size() == 0) { // If only Crop operator exists
|
if (resize_para.empty()) { // If only Crop operator exists
|
||||||
aipp_size = crop_para;
|
aipp_size = crop_para;
|
||||||
} else if (crop_para.size() == 0) { // If only Resize operator with 2 parameters exists
|
} else if (crop_para.empty()) { // If only Resize operator with 2 parameters exists
|
||||||
aipp_size = resize_para;
|
aipp_size = resize_para;
|
||||||
} else { // If both of them exist
|
} else { // If both of them exist
|
||||||
if (resize_para.size() == 1) {
|
if (resize_para.size() == 1) {
|
||||||
|
@ -437,6 +437,7 @@ std::vector<float> AippStdFilter(const std::vector<uint32_t> &normalize_para) {
|
||||||
|
|
||||||
Status AippInfoCollection(std::map<std::string, std::string> *aipp_options, const std::vector<uint32_t> &aipp_size,
|
Status AippInfoCollection(std::map<std::string, std::string> *aipp_options, const std::vector<uint32_t> &aipp_size,
|
||||||
const std::vector<uint32_t> &aipp_mean, const std::vector<float> &aipp_std) {
|
const std::vector<uint32_t> &aipp_mean, const std::vector<float> &aipp_std) {
|
||||||
|
RETURN_UNEXPECTED_IF_NULL(aipp_options);
|
||||||
// Several aipp config parameters
|
// Several aipp config parameters
|
||||||
aipp_options->insert(std::make_pair("related_input_rank", "0"));
|
aipp_options->insert(std::make_pair("related_input_rank", "0"));
|
||||||
aipp_options->insert(std::make_pair("src_image_size_w", std::to_string(aipp_size[1])));
|
aipp_options->insert(std::make_pair("src_image_size_w", std::to_string(aipp_size[1])));
|
||||||
|
@ -489,10 +490,7 @@ std::string Execute::AippCfgGenerator() {
|
||||||
#ifdef ENABLE_ACL
|
#ifdef ENABLE_ACL
|
||||||
if (info_->init_with_shared_ptr_) {
|
if (info_->init_with_shared_ptr_) {
|
||||||
auto rc = ParseTransforms();
|
auto rc = ParseTransforms();
|
||||||
if (rc.IsError()) {
|
RETURN_SECOND_IF_ERROR(rc, "");
|
||||||
MS_LOG(ERROR) << "Parse transforms failed, error msg is " << rc;
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
info_->init_with_shared_ptr_ = false;
|
info_->init_with_shared_ptr_ = false;
|
||||||
}
|
}
|
||||||
std::vector<uint32_t> paras; // Record the parameters value of each Ascend operators
|
std::vector<uint32_t> paras; // Record the parameters value of each Ascend operators
|
||||||
|
@ -574,6 +572,7 @@ std::string Execute::AippCfgGenerator() {
|
||||||
auto rc = AippInfoCollection(&aipp_options, aipp_size, aipp_mean, aipp_std);
|
auto rc = AippInfoCollection(&aipp_options, aipp_size, aipp_mean, aipp_std);
|
||||||
if (rc.IsError()) {
|
if (rc.IsError()) {
|
||||||
MS_LOG(ERROR) << "Aipp information initialization failed, error msg is " << rc;
|
MS_LOG(ERROR) << "Aipp information initialization failed, error msg is " << rc;
|
||||||
|
outfile.close();
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -594,7 +593,7 @@ std::string Execute::AippCfgGenerator() {
|
||||||
return config_location;
|
return config_location;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsEmptyPtr(std::shared_ptr<TensorTransform> api_ptr) { return api_ptr == nullptr; }
|
bool IsEmptyPtr(const std::shared_ptr<TensorTransform> &api_ptr) { return api_ptr == nullptr; }
|
||||||
|
|
||||||
Status Execute::ParseTransforms() {
|
Status Execute::ParseTransforms() {
|
||||||
auto iter = std::find_if(transforms_.begin(), transforms_.end(), IsEmptyPtr);
|
auto iter = std::find_if(transforms_.begin(), transforms_.end(), IsEmptyPtr);
|
||||||
|
@ -606,7 +605,7 @@ Status Execute::ParseTransforms() {
|
||||||
|
|
||||||
if (device_type_ == MapTargetDevice::kCpu) {
|
if (device_type_ == MapTargetDevice::kCpu) {
|
||||||
(void)std::transform(transforms_.begin(), transforms_.end(), std::back_inserter(ops_),
|
(void)std::transform(transforms_.begin(), transforms_.end(), std::back_inserter(ops_),
|
||||||
[](std::shared_ptr<TensorTransform> operation) -> std::shared_ptr<TensorOperation> {
|
[](const std::shared_ptr<TensorTransform> &operation) -> std::shared_ptr<TensorOperation> {
|
||||||
return operation->Parse();
|
return operation->Parse();
|
||||||
});
|
});
|
||||||
} else if (device_type_ == MapTargetDevice::kAscend310) {
|
} else if (device_type_ == MapTargetDevice::kAscend310) {
|
||||||
|
@ -645,10 +644,11 @@ Status Execute::DeviceMemoryRelease() {
|
||||||
|
|
||||||
Status Execute::Run(const std::vector<std::shared_ptr<dataset::Execute>> &data_graph,
|
Status Execute::Run(const std::vector<std::shared_ptr<dataset::Execute>> &data_graph,
|
||||||
const std::vector<mindspore::MSTensor> &inputs, std::vector<mindspore::MSTensor> *outputs) {
|
const std::vector<mindspore::MSTensor> &inputs, std::vector<mindspore::MSTensor> *outputs) {
|
||||||
|
RETURN_UNEXPECTED_IF_NULL(outputs);
|
||||||
std::vector<MSTensor> transform_inputs = inputs;
|
std::vector<MSTensor> transform_inputs = inputs;
|
||||||
std::vector<MSTensor> transform_outputs;
|
std::vector<MSTensor> transform_outputs;
|
||||||
if (!data_graph.empty()) {
|
if (!data_graph.empty()) {
|
||||||
for (auto exes : data_graph) {
|
for (const auto &exes : data_graph) {
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(exes != nullptr, "Given execute object is null.");
|
CHECK_FAIL_RETURN_UNEXPECTED(exes != nullptr, "Given execute object is null.");
|
||||||
Status ret = exes->operator()(transform_inputs, &transform_outputs);
|
Status ret = exes->operator()(transform_inputs, &transform_outputs);
|
||||||
if (ret != kSuccess) {
|
if (ret != kSuccess) {
|
||||||
|
@ -675,9 +675,11 @@ extern "C" {
|
||||||
void ExecuteRun_C(const std::vector<std::shared_ptr<dataset::Execute>> &data_graph,
|
void ExecuteRun_C(const std::vector<std::shared_ptr<dataset::Execute>> &data_graph,
|
||||||
std::vector<mindspore::MSTensor> &inputs, std::vector<mindspore::MSTensor> *outputs, Status *s) {
|
std::vector<mindspore::MSTensor> &inputs, std::vector<mindspore::MSTensor> *outputs, Status *s) {
|
||||||
Status ret = Execute::Run(data_graph, inputs, outputs);
|
Status ret = Execute::Run(data_graph, inputs, outputs);
|
||||||
|
if (s == nullptr) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
*s = Status(ret);
|
*s = Status(ret);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -14,6 +14,7 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
#include "minddata/dataset/include/dataset/iterator.h"
|
#include "minddata/dataset/include/dataset/iterator.h"
|
||||||
|
|
||||||
#include "minddata/dataset/engine/consumers/pull_based_tree_consumer.h"
|
#include "minddata/dataset/engine/consumers/pull_based_tree_consumer.h"
|
||||||
#include "minddata/dataset/engine/consumers/tree_consumer.h"
|
#include "minddata/dataset/engine/consumers/tree_consumer.h"
|
||||||
#include "minddata/dataset/engine/runtime_context.h"
|
#include "minddata/dataset/engine/runtime_context.h"
|
||||||
|
@ -75,7 +76,7 @@ void Iterator::Stop() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Function to build and launch the execution tree.
|
// Function to build and launch the execution tree.
|
||||||
Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds, int32_t num_epochs) {
|
Status Iterator::BuildAndLaunchTree(const std::shared_ptr<Dataset> &ds, int32_t num_epochs) {
|
||||||
RETURN_UNEXPECTED_IF_NULL(ds);
|
RETURN_UNEXPECTED_IF_NULL(ds);
|
||||||
runtime_context_ = std::make_unique<NativeRuntimeContext>();
|
runtime_context_ = std::make_unique<NativeRuntimeContext>();
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(runtime_context_ != nullptr, "Create runtime_context_ failed.");
|
CHECK_FAIL_RETURN_UNEXPECTED(runtime_context_ != nullptr, "Create runtime_context_ failed.");
|
||||||
|
@ -105,7 +106,7 @@ Status PullIterator::GetRows(int32_t num_rows, std::vector<MSTensorVec> *const r
|
||||||
}
|
}
|
||||||
|
|
||||||
MSTensorVec ms_row = {};
|
MSTensorVec ms_row = {};
|
||||||
for (auto de_tensor : md_row) {
|
for (const auto &de_tensor : md_row) {
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(de_tensor->HasData(), "Apply transform failed, output tensor has no data");
|
CHECK_FAIL_RETURN_UNEXPECTED(de_tensor->HasData(), "Apply transform failed, output tensor has no data");
|
||||||
ms_row.push_back(mindspore::MSTensor(std::make_shared<DETensor>(de_tensor)));
|
ms_row.push_back(mindspore::MSTensor(std::make_shared<DETensor>(de_tensor)));
|
||||||
}
|
}
|
||||||
|
@ -125,7 +126,7 @@ Status PullIterator::GetNextRow(MSTensorVec *const row) {
|
||||||
return rc;
|
return rc;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto de_tensor : md_row) {
|
for (const auto &de_tensor : md_row) {
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(de_tensor->HasData(), "Apply transform failed, output tensor has no data");
|
CHECK_FAIL_RETURN_UNEXPECTED(de_tensor->HasData(), "Apply transform failed, output tensor has no data");
|
||||||
row->push_back(mindspore::MSTensor(std::make_shared<DETensor>(de_tensor)));
|
row->push_back(mindspore::MSTensor(std::make_shared<DETensor>(de_tensor)));
|
||||||
}
|
}
|
||||||
|
@ -135,7 +136,7 @@ Status PullIterator::GetNextRow(MSTensorVec *const row) {
|
||||||
// Function to build and launch the execution tree. This function kicks off a different type of consumer
|
// Function to build and launch the execution tree. This function kicks off a different type of consumer
|
||||||
// for the tree, the reason why this is the case is due to the fact that PullBasedIterator does not need
|
// for the tree, the reason why this is the case is due to the fact that PullBasedIterator does not need
|
||||||
// to instantiate threads for each op. As such, the call to the consumer will by pass the execution tree.
|
// to instantiate threads for each op. As such, the call to the consumer will by pass the execution tree.
|
||||||
Status PullIterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) {
|
Status PullIterator::BuildAndLaunchTree(const std::shared_ptr<Dataset> &ds) {
|
||||||
if (pull_consumer_ == nullptr) {
|
if (pull_consumer_ == nullptr) {
|
||||||
pull_consumer_ = std::make_unique<PullBasedIteratorConsumer>();
|
pull_consumer_ = std::make_unique<PullBasedIteratorConsumer>();
|
||||||
}
|
}
|
||||||
|
@ -167,7 +168,7 @@ Iterator::_Iterator &Iterator::_Iterator::operator++() {
|
||||||
cur_row_ = nullptr;
|
cur_row_ = nullptr;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (cur_row_ && cur_row_->size() == 0) {
|
if (cur_row_ && cur_row_->empty()) {
|
||||||
delete cur_row_;
|
delete cur_row_;
|
||||||
cur_row_ = nullptr;
|
cur_row_ = nullptr;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -60,7 +60,6 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
|
||||||
PYBIND_REGISTER(
|
PYBIND_REGISTER(
|
||||||
AllpassBiquadOperation, 1, ([](const py::module *m) {
|
AllpassBiquadOperation, 1, ([](const py::module *m) {
|
||||||
(void)py::class_<audio::AllpassBiquadOperation, TensorOperation, std::shared_ptr<audio::AllpassBiquadOperation>>(
|
(void)py::class_<audio::AllpassBiquadOperation, TensorOperation, std::shared_ptr<audio::AllpassBiquadOperation>>(
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -13,15 +13,17 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
#include "pybind11/numpy.h"
|
||||||
#include "pybind11/pybind11.h"
|
#include "pybind11/pybind11.h"
|
||||||
#include "pybind11/stl_bind.h"
|
#include "pybind11/stl_bind.h"
|
||||||
|
|
||||||
#include "minddata/dataset/api/python/pybind_register.h"
|
#include "minddata/dataset/api/python/pybind_register.h"
|
||||||
#include "minddata/dataset/core/global_context.h"
|
|
||||||
#include "minddata/dataset/core/client.h" // DE client
|
#include "minddata/dataset/core/client.h" // DE client
|
||||||
#include "minddata/dataset/util/status.h"
|
#include "minddata/dataset/core/global_context.h"
|
||||||
#include "pybind11/numpy.h"
|
|
||||||
#include "minddata/dataset/include/dataset/constants.h"
|
#include "minddata/dataset/include/dataset/constants.h"
|
||||||
|
#include "minddata/dataset/util/status.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
@ -63,12 +65,12 @@ PYBIND_REGISTER(ConfigManager, 0, ([](const py::module *m) {
|
||||||
.def("get_enable_autotune", &ConfigManager::enable_autotune)
|
.def("get_enable_autotune", &ConfigManager::enable_autotune)
|
||||||
.def("set_autotune_interval", &ConfigManager::set_autotune_interval)
|
.def("set_autotune_interval", &ConfigManager::set_autotune_interval)
|
||||||
.def("get_autotune_interval", &ConfigManager::autotune_interval)
|
.def("get_autotune_interval", &ConfigManager::autotune_interval)
|
||||||
.def("load", [](ConfigManager &c, std::string s) { THROW_IF_ERROR(c.LoadFile(s)); });
|
.def("load", [](ConfigManager &c, const std::string &s) { THROW_IF_ERROR(c.LoadFile(s)); });
|
||||||
}));
|
}));
|
||||||
|
|
||||||
PYBIND_REGISTER(Tensor, 0, ([](const py::module *m) {
|
PYBIND_REGISTER(Tensor, 0, ([](const py::module *m) {
|
||||||
(void)py::class_<Tensor, std::shared_ptr<Tensor>>(*m, "Tensor", py::buffer_protocol())
|
(void)py::class_<Tensor, std::shared_ptr<Tensor>>(*m, "Tensor", py::buffer_protocol())
|
||||||
.def(py::init([](py::array arr) {
|
.def(py::init([](const py::array &arr) {
|
||||||
std::shared_ptr<Tensor> out;
|
std::shared_ptr<Tensor> out;
|
||||||
THROW_IF_ERROR(Tensor::CreateFromNpArray(arr, &out));
|
THROW_IF_ERROR(Tensor::CreateFromNpArray(arr, &out));
|
||||||
return out;
|
return out;
|
||||||
|
@ -107,7 +109,7 @@ PYBIND_REGISTER(DataType, 0, ([](const py::module *m) {
|
||||||
.def(py::init<std::string>())
|
.def(py::init<std::string>())
|
||||||
.def(py::self == py::self)
|
.def(py::self == py::self)
|
||||||
.def("__str__", &DataType::ToString)
|
.def("__str__", &DataType::ToString)
|
||||||
.def("__deepcopy__", [](py::object &t, py::dict memo) { return t; });
|
.def("__deepcopy__", [](py::object &t, const py::dict &memo) { return t; });
|
||||||
}));
|
}));
|
||||||
|
|
||||||
PYBIND_REGISTER(AutoAugmentPolicy, 0, ([](const py::module *m) {
|
PYBIND_REGISTER(AutoAugmentPolicy, 0, ([](const py::module *m) {
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -23,7 +23,6 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
|
||||||
PYBIND_REGISTER(
|
PYBIND_REGISTER(
|
||||||
Graph, 0, ([](const py::module *m) {
|
Graph, 0, ([](const py::module *m) {
|
||||||
(void)py::class_<gnn::GraphData, std::shared_ptr<gnn::GraphData>>(*m, "GraphDataClient")
|
(void)py::class_<gnn::GraphData, std::shared_ptr<gnn::GraphData>>(*m, "GraphDataClient")
|
||||||
|
@ -51,7 +50,7 @@ PYBIND_REGISTER(
|
||||||
return out;
|
return out;
|
||||||
})
|
})
|
||||||
.def("get_nodes_from_edges",
|
.def("get_nodes_from_edges",
|
||||||
[](gnn::GraphData &g, std::vector<gnn::NodeIdType> edge_list) {
|
[](gnn::GraphData &g, const std::vector<gnn::NodeIdType> &edge_list) {
|
||||||
std::shared_ptr<Tensor> out;
|
std::shared_ptr<Tensor> out;
|
||||||
THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out));
|
THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out));
|
||||||
return out;
|
return out;
|
||||||
|
@ -70,27 +69,30 @@ PYBIND_REGISTER(
|
||||||
return out;
|
return out;
|
||||||
})
|
})
|
||||||
.def("get_sampled_neighbors",
|
.def("get_sampled_neighbors",
|
||||||
[](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeIdType> neighbor_nums,
|
[](gnn::GraphData &g, const std::vector<gnn::NodeIdType> &node_list,
|
||||||
std::vector<gnn::NodeType> neighbor_types, SamplingStrategy strategy) {
|
const std::vector<gnn::NodeIdType> &neighbor_nums, const std::vector<gnn::NodeType> &neighbor_types,
|
||||||
|
SamplingStrategy strategy) {
|
||||||
std::shared_ptr<Tensor> out;
|
std::shared_ptr<Tensor> out;
|
||||||
THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, strategy, &out));
|
THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, strategy, &out));
|
||||||
return out;
|
return out;
|
||||||
})
|
})
|
||||||
.def("get_neg_sampled_neighbors",
|
.def("get_neg_sampled_neighbors",
|
||||||
[](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeIdType neighbor_num,
|
[](gnn::GraphData &g, const std::vector<gnn::NodeIdType> &node_list, gnn::NodeIdType neighbor_num,
|
||||||
gnn::NodeType neg_neighbor_type) {
|
gnn::NodeType neg_neighbor_type) {
|
||||||
std::shared_ptr<Tensor> out;
|
std::shared_ptr<Tensor> out;
|
||||||
THROW_IF_ERROR(g.GetNegSampledNeighbors(node_list, neighbor_num, neg_neighbor_type, &out));
|
THROW_IF_ERROR(g.GetNegSampledNeighbors(node_list, neighbor_num, neg_neighbor_type, &out));
|
||||||
return out;
|
return out;
|
||||||
})
|
})
|
||||||
.def("get_node_feature",
|
.def(
|
||||||
[](gnn::GraphData &g, std::shared_ptr<Tensor> node_list, std::vector<gnn::FeatureType> feature_types) {
|
"get_node_feature",
|
||||||
|
[](gnn::GraphData &g, const std::shared_ptr<Tensor> &node_list, std::vector<gnn::FeatureType> feature_types) {
|
||||||
TensorRow out;
|
TensorRow out;
|
||||||
THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out));
|
THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out));
|
||||||
return out.getRow();
|
return out.getRow();
|
||||||
})
|
})
|
||||||
.def("get_edge_feature",
|
.def(
|
||||||
[](gnn::GraphData &g, std::shared_ptr<Tensor> edge_list, std::vector<gnn::FeatureType> feature_types) {
|
"get_edge_feature",
|
||||||
|
[](gnn::GraphData &g, const std::shared_ptr<Tensor> &edge_list, std::vector<gnn::FeatureType> feature_types) {
|
||||||
TensorRow out;
|
TensorRow out;
|
||||||
THROW_IF_ERROR(g.GetEdgeFeature(edge_list, feature_types, &out));
|
THROW_IF_ERROR(g.GetEdgeFeature(edge_list, feature_types, &out));
|
||||||
return out.getRow();
|
return out.getRow();
|
||||||
|
@ -102,8 +104,9 @@ PYBIND_REGISTER(
|
||||||
return out;
|
return out;
|
||||||
})
|
})
|
||||||
.def("random_walk",
|
.def("random_walk",
|
||||||
[](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeType> meta_path,
|
[](gnn::GraphData &g, const std::vector<gnn::NodeIdType> &node_list,
|
||||||
float step_home_param, float step_away_param, gnn::NodeIdType default_node) {
|
const std::vector<gnn::NodeType> &meta_path, float step_home_param, float step_away_param,
|
||||||
|
gnn::NodeIdType default_node) {
|
||||||
std::shared_ptr<Tensor> out;
|
std::shared_ptr<Tensor> out;
|
||||||
THROW_IF_ERROR(g.RandomWalk(node_list, meta_path, step_home_param, step_away_param, default_node, &out));
|
THROW_IF_ERROR(g.RandomWalk(node_list, meta_path, step_home_param, step_away_param, default_node, &out));
|
||||||
return out;
|
return out;
|
||||||
|
@ -137,6 +140,5 @@ PYBIND_REGISTER(OutputFormat, 0, ([](const py::module *m) {
|
||||||
.value("DE_FORMAT_CSR", OutputFormat::kCsr)
|
.value("DE_FORMAT_CSR", OutputFormat::kCsr)
|
||||||
.export_values();
|
.export_values();
|
||||||
}));
|
}));
|
||||||
|
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -18,9 +18,12 @@
|
||||||
#include "minddata/dataset/api/python/pybind_conversion.h"
|
#include "minddata/dataset/api/python/pybind_conversion.h"
|
||||||
#include "minddata/dataset/api/python/pybind_register.h"
|
#include "minddata/dataset/api/python/pybind_register.h"
|
||||||
#include "minddata/dataset/callback/py_ds_callback.h"
|
#include "minddata/dataset/callback/py_ds_callback.h"
|
||||||
#include "minddata/dataset/include/dataset/constants.h"
|
#include "minddata/dataset/core/config_manager.h"
|
||||||
|
#include "minddata/dataset/core/data_type.h"
|
||||||
#include "minddata/dataset/engine/serdes.h"
|
#include "minddata/dataset/engine/serdes.h"
|
||||||
|
#include "minddata/dataset/include/dataset/constants.h"
|
||||||
#include "minddata/dataset/text/sentence_piece_vocab.h"
|
#include "minddata/dataset/text/sentence_piece_vocab.h"
|
||||||
|
#include "minddata/dataset/util/path.h"
|
||||||
|
|
||||||
// IR non-leaf nodes
|
// IR non-leaf nodes
|
||||||
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
|
||||||
|
@ -40,37 +43,33 @@
|
||||||
// IR non-leaf nodes - for android
|
// IR non-leaf nodes - for android
|
||||||
#ifndef ENABLE_ANDROID
|
#ifndef ENABLE_ANDROID
|
||||||
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
|
||||||
#include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
|
|
||||||
#include "minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h"
|
||||||
|
#include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
|
||||||
#include "minddata/dataset/engine/ir/datasetops/sync_wait_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/sync_wait_node.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include "minddata/dataset/core/config_manager.h"
|
|
||||||
#include "minddata/dataset/core/data_type.h"
|
|
||||||
#include "minddata/dataset/util/path.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
PYBIND_REGISTER(DatasetNode, 1, ([](const py::module *m) {
|
PYBIND_REGISTER(DatasetNode, 1, ([](const py::module *m) {
|
||||||
(void)py::class_<DatasetNode, std::shared_ptr<DatasetNode>>(*m, "Dataset")
|
(void)py::class_<DatasetNode, std::shared_ptr<DatasetNode>>(*m, "Dataset")
|
||||||
.def("set_num_workers",
|
.def("set_num_workers",
|
||||||
[](std::shared_ptr<DatasetNode> self, std::optional<int32_t> num_workers) {
|
[](const std::shared_ptr<DatasetNode> &self, std::optional<int32_t> num_workers) {
|
||||||
return num_workers ? self->SetNumWorkers(*num_workers) : self;
|
return num_workers ? self->SetNumWorkers(*num_workers) : self;
|
||||||
})
|
})
|
||||||
.def("set_cache_client",
|
.def("set_cache_client",
|
||||||
[](std::shared_ptr<DatasetNode> self, std::shared_ptr<CacheClient> cc) {
|
[](const std::shared_ptr<DatasetNode> &self, std::shared_ptr<CacheClient> cc) {
|
||||||
return self->SetDatasetCache(toDatasetCache(std::move(cc)));
|
return self->SetDatasetCache(toDatasetCache(std::move(cc)));
|
||||||
})
|
})
|
||||||
.def(
|
.def(
|
||||||
"Zip",
|
"Zip",
|
||||||
[](std::shared_ptr<DatasetNode> self, py::list datasets) {
|
[](const std::shared_ptr<DatasetNode> &self, const py::list &datasets) {
|
||||||
auto zip = std::make_shared<ZipNode>(std::move(toDatasetNode(self, datasets)));
|
auto zip = std::make_shared<ZipNode>(std::move(toDatasetNode(self, datasets)));
|
||||||
THROW_IF_ERROR(zip->ValidateParams());
|
THROW_IF_ERROR(zip->ValidateParams());
|
||||||
return zip;
|
return zip;
|
||||||
},
|
},
|
||||||
py::arg("datasets"))
|
py::arg("datasets"))
|
||||||
.def("to_json",
|
.def("to_json",
|
||||||
[](std::shared_ptr<DatasetNode> self, const std::string &json_filepath) {
|
[](const std::shared_ptr<DatasetNode> &self, const std::string &json_filepath) {
|
||||||
nlohmann::json args;
|
nlohmann::json args;
|
||||||
THROW_IF_ERROR(Serdes::SaveToJSON(self, json_filepath, &args));
|
THROW_IF_ERROR(Serdes::SaveToJSON(self, json_filepath, &args));
|
||||||
return args.dump();
|
return args.dump();
|
||||||
|
@ -92,23 +91,22 @@ PYBIND_REGISTER(DatasetNode, 1, ([](const py::module *m) {
|
||||||
// PYBIND FOR NON-LEAF NODES
|
// PYBIND FOR NON-LEAF NODES
|
||||||
// (In alphabetical order)
|
// (In alphabetical order)
|
||||||
|
|
||||||
PYBIND_REGISTER(BatchNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(
|
||||||
(void)py::class_<BatchNode, DatasetNode, std::shared_ptr<BatchNode>>(*m, "BatchNode",
|
BatchNode, 2, ([](const py::module *m) {
|
||||||
"to create a BatchNode")
|
(void)py::class_<BatchNode, DatasetNode, std::shared_ptr<BatchNode>>(*m, "BatchNode", "to create a BatchNode")
|
||||||
.def(py::init([](std::shared_ptr<DatasetNode> self, int32_t batch_size, bool drop_remainder,
|
.def(py::init([](const std::shared_ptr<DatasetNode> &self, int32_t batch_size, bool drop_remainder, bool pad,
|
||||||
bool pad, py::list in_col_names, py::list out_col_names, py::list col_order,
|
const py::list &in_col_names, const py::list &out_col_names, const py::list &col_order,
|
||||||
py::object size_obj, py::object map_obj, py::dict pad_info) {
|
const py::object &size_obj, const py::object &map_obj, const py::dict &pad_info) {
|
||||||
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> c_pad_info;
|
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> c_pad_info;
|
||||||
if (pad) {
|
if (pad) {
|
||||||
THROW_IF_ERROR(toPadInfo(pad_info, &c_pad_info));
|
THROW_IF_ERROR(toPadInfo(pad_info, &c_pad_info));
|
||||||
}
|
}
|
||||||
py::function size_func =
|
py::function size_func =
|
||||||
py::isinstance<py::function>(size_obj) ? size_obj.cast<py::function>() : py::function();
|
py::isinstance<py::function>(size_obj) ? size_obj.cast<py::function>() : py::function();
|
||||||
py::function map_func =
|
py::function map_func = py::isinstance<py::function>(map_obj) ? map_obj.cast<py::function>() : py::function();
|
||||||
py::isinstance<py::function>(map_obj) ? map_obj.cast<py::function>() : py::function();
|
auto batch = std::make_shared<BatchNode>(self, batch_size, drop_remainder, pad, toStringVector(in_col_names),
|
||||||
auto batch = std::make_shared<BatchNode>(
|
toStringVector(out_col_names), toStringVector(col_order), size_func,
|
||||||
self, batch_size, drop_remainder, pad, toStringVector(in_col_names),
|
map_func, c_pad_info);
|
||||||
toStringVector(out_col_names), toStringVector(col_order), size_func, map_func, c_pad_info);
|
|
||||||
THROW_IF_ERROR(batch->ValidateParams());
|
THROW_IF_ERROR(batch->ValidateParams());
|
||||||
return batch;
|
return batch;
|
||||||
}));
|
}));
|
||||||
|
@ -117,10 +115,10 @@ PYBIND_REGISTER(BatchNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(BucketBatchByLengthNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(BucketBatchByLengthNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<BucketBatchByLengthNode, DatasetNode, std::shared_ptr<BucketBatchByLengthNode>>(
|
(void)py::class_<BucketBatchByLengthNode, DatasetNode, std::shared_ptr<BucketBatchByLengthNode>>(
|
||||||
*m, "BucketBatchByLengthNode", "to create a BucketBatchByLengthNode")
|
*m, "BucketBatchByLengthNode", "to create a BucketBatchByLengthNode")
|
||||||
.def(py::init([](std::shared_ptr<DatasetNode> dataset, py::list column_names,
|
.def(py::init([](const std::shared_ptr<DatasetNode> &dataset, const py::list &column_names,
|
||||||
std::vector<int32_t> bucket_boundaries, std::vector<int32_t> bucket_batch_sizes,
|
const std::vector<int32_t> &bucket_boundaries,
|
||||||
py::object element_length_function, py::dict pad_info, bool pad_to_bucket_boundary,
|
const std::vector<int32_t> &bucket_batch_sizes, py::object element_length_function,
|
||||||
bool drop_remainder) {
|
const py::dict &pad_info, bool pad_to_bucket_boundary, bool drop_remainder) {
|
||||||
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> c_pad_info;
|
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> c_pad_info;
|
||||||
THROW_IF_ERROR(toPadInfo(pad_info, &c_pad_info));
|
THROW_IF_ERROR(toPadInfo(pad_info, &c_pad_info));
|
||||||
|
|
||||||
|
@ -139,10 +137,10 @@ PYBIND_REGISTER(BucketBatchByLengthNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(BuildSentenceVocabNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(BuildSentenceVocabNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<BuildSentenceVocabNode, DatasetNode, std::shared_ptr<BuildSentenceVocabNode>>(
|
(void)py::class_<BuildSentenceVocabNode, DatasetNode, std::shared_ptr<BuildSentenceVocabNode>>(
|
||||||
*m, "BuildSentenceVocabNode", "to create a BuildSentenceVocabNode")
|
*m, "BuildSentenceVocabNode", "to create a BuildSentenceVocabNode")
|
||||||
.def(py::init([](std::shared_ptr<DatasetNode> self, std::shared_ptr<SentencePieceVocab> vocab,
|
.def(py::init(
|
||||||
const std::vector<std::string> &col_names, int32_t vocab_size,
|
[](const std::shared_ptr<DatasetNode> &self, const std::shared_ptr<SentencePieceVocab> &vocab,
|
||||||
float character_coverage, SentencePieceModel model_type,
|
const std::vector<std::string> &col_names, int32_t vocab_size, float character_coverage,
|
||||||
const std::unordered_map<std::string, std::string> ¶ms) {
|
SentencePieceModel model_type, const std::unordered_map<std::string, std::string> ¶ms) {
|
||||||
auto build_sentence_vocab = std::make_shared<BuildSentenceVocabNode>(
|
auto build_sentence_vocab = std::make_shared<BuildSentenceVocabNode>(
|
||||||
self, vocab, col_names, vocab_size, character_coverage, model_type, params);
|
self, vocab, col_names, vocab_size, character_coverage, model_type, params);
|
||||||
THROW_IF_ERROR(build_sentence_vocab->ValidateParams());
|
THROW_IF_ERROR(build_sentence_vocab->ValidateParams());
|
||||||
|
@ -153,8 +151,9 @@ PYBIND_REGISTER(BuildSentenceVocabNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(BuildVocabNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(BuildVocabNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<BuildVocabNode, DatasetNode, std::shared_ptr<BuildVocabNode>>(
|
(void)py::class_<BuildVocabNode, DatasetNode, std::shared_ptr<BuildVocabNode>>(
|
||||||
*m, "BuildVocabNode", "to create a BuildVocabNode")
|
*m, "BuildVocabNode", "to create a BuildVocabNode")
|
||||||
.def(py::init([](std::shared_ptr<DatasetNode> self, std::shared_ptr<Vocab> vocab, py::list columns,
|
.def(py::init([](const std::shared_ptr<DatasetNode> &self, const std::shared_ptr<Vocab> &vocab,
|
||||||
py::tuple freq_range, int64_t top_k, py::list special_tokens, bool special_first) {
|
const py::list &columns, const py::tuple &freq_range, int64_t top_k,
|
||||||
|
py::list special_tokens, bool special_first) {
|
||||||
auto build_vocab =
|
auto build_vocab =
|
||||||
std::make_shared<BuildVocabNode>(self, vocab, toStringVector(columns), toIntPair(freq_range),
|
std::make_shared<BuildVocabNode>(self, vocab, toStringVector(columns), toIntPair(freq_range),
|
||||||
top_k, toStringVector(special_tokens), special_first);
|
top_k, toStringVector(special_tokens), special_first);
|
||||||
|
@ -166,8 +165,8 @@ PYBIND_REGISTER(BuildVocabNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(ConcatNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(ConcatNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<ConcatNode, DatasetNode, std::shared_ptr<ConcatNode>>(*m, "ConcatNode",
|
(void)py::class_<ConcatNode, DatasetNode, std::shared_ptr<ConcatNode>>(*m, "ConcatNode",
|
||||||
"to create a ConcatNode")
|
"to create a ConcatNode")
|
||||||
.def(py::init([](std::vector<std::shared_ptr<DatasetNode>> datasets, py::handle sampler,
|
.def(py::init([](const std::vector<std::shared_ptr<DatasetNode>> &datasets, py::handle sampler,
|
||||||
py::list children_flag_and_nums, py::list children_start_end_index) {
|
const py::list &children_flag_and_nums, const py::list &children_start_end_index) {
|
||||||
auto concat = std::make_shared<ConcatNode>(datasets, toSamplerObj(sampler),
|
auto concat = std::make_shared<ConcatNode>(datasets, toSamplerObj(sampler),
|
||||||
toPairVector(children_flag_and_nums),
|
toPairVector(children_flag_and_nums),
|
||||||
toPairVector(children_start_end_index));
|
toPairVector(children_start_end_index));
|
||||||
|
@ -179,8 +178,8 @@ PYBIND_REGISTER(ConcatNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(FilterNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(FilterNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<FilterNode, DatasetNode, std::shared_ptr<FilterNode>>(*m, "FilterNode",
|
(void)py::class_<FilterNode, DatasetNode, std::shared_ptr<FilterNode>>(*m, "FilterNode",
|
||||||
"to create a FilterNode")
|
"to create a FilterNode")
|
||||||
.def(py::init([](std::shared_ptr<DatasetNode> self, py::object predicate,
|
.def(py::init([](const std::shared_ptr<DatasetNode> &self, const py::object &predicate,
|
||||||
std::vector<std::string> input_columns) {
|
const std::vector<std::string> &input_columns) {
|
||||||
auto filter =
|
auto filter =
|
||||||
std::make_shared<FilterNode>(self, toPyFuncOp(predicate, DataType::DE_BOOL), input_columns);
|
std::make_shared<FilterNode>(self, toPyFuncOp(predicate, DataType::DE_BOOL), input_columns);
|
||||||
THROW_IF_ERROR(filter->ValidateParams());
|
THROW_IF_ERROR(filter->ValidateParams());
|
||||||
|
@ -190,8 +189,9 @@ PYBIND_REGISTER(FilterNode, 2, ([](const py::module *m) {
|
||||||
|
|
||||||
PYBIND_REGISTER(MapNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(MapNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<MapNode, DatasetNode, std::shared_ptr<MapNode>>(*m, "MapNode", "to create a MapNode")
|
(void)py::class_<MapNode, DatasetNode, std::shared_ptr<MapNode>>(*m, "MapNode", "to create a MapNode")
|
||||||
.def(py::init([](std::shared_ptr<DatasetNode> self, py::list operations, py::list input_columns,
|
.def(py::init([](const std::shared_ptr<DatasetNode> &self, const py::list &operations,
|
||||||
py::list output_columns, py::list project_columns,
|
const py::list &input_columns, const py::list &output_columns,
|
||||||
|
const py::list &project_columns,
|
||||||
std::vector<std::shared_ptr<PyDSCallback>> py_callbacks, int64_t max_rowsize,
|
std::vector<std::shared_ptr<PyDSCallback>> py_callbacks, int64_t max_rowsize,
|
||||||
const ManualOffloadMode offload) {
|
const ManualOffloadMode offload) {
|
||||||
auto map = std::make_shared<MapNode>(
|
auto map = std::make_shared<MapNode>(
|
||||||
|
@ -206,18 +206,20 @@ PYBIND_REGISTER(MapNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(ProjectNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(ProjectNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<ProjectNode, DatasetNode, std::shared_ptr<ProjectNode>>(*m, "ProjectNode",
|
(void)py::class_<ProjectNode, DatasetNode, std::shared_ptr<ProjectNode>>(*m, "ProjectNode",
|
||||||
"to create a ProjectNode")
|
"to create a ProjectNode")
|
||||||
.def(py::init([](std::shared_ptr<DatasetNode> self, py::list columns) {
|
.def(py::init([](const std::shared_ptr<DatasetNode> &self, const py::list &columns) {
|
||||||
auto project = std::make_shared<ProjectNode>(self, toStringVector(columns));
|
auto project = std::make_shared<ProjectNode>(self, toStringVector(columns));
|
||||||
THROW_IF_ERROR(project->ValidateParams());
|
THROW_IF_ERROR(project->ValidateParams());
|
||||||
return project;
|
return project;
|
||||||
}));
|
}));
|
||||||
}));
|
}));
|
||||||
|
|
||||||
PYBIND_REGISTER(
|
PYBIND_REGISTER(RenameNode, 2, ([](const py::module *m) {
|
||||||
RenameNode, 2, ([](const py::module *m) {
|
(void)py::class_<RenameNode, DatasetNode, std::shared_ptr<RenameNode>>(*m, "RenameNode",
|
||||||
(void)py::class_<RenameNode, DatasetNode, std::shared_ptr<RenameNode>>(*m, "RenameNode", "to create a RenameNode")
|
"to create a RenameNode")
|
||||||
.def(py::init([](std::shared_ptr<DatasetNode> self, py::list input_columns, py::list output_columns) {
|
.def(py::init([](const std::shared_ptr<DatasetNode> &self, const py::list &input_columns,
|
||||||
auto rename = std::make_shared<RenameNode>(self, toStringVector(input_columns), toStringVector(output_columns));
|
const py::list &output_columns) {
|
||||||
|
auto rename = std::make_shared<RenameNode>(self, toStringVector(input_columns),
|
||||||
|
toStringVector(output_columns));
|
||||||
THROW_IF_ERROR(rename->ValidateParams());
|
THROW_IF_ERROR(rename->ValidateParams());
|
||||||
return rename;
|
return rename;
|
||||||
}));
|
}));
|
||||||
|
@ -226,7 +228,7 @@ PYBIND_REGISTER(
|
||||||
PYBIND_REGISTER(RepeatNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(RepeatNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<RepeatNode, DatasetNode, std::shared_ptr<RepeatNode>>(*m, "RepeatNode",
|
(void)py::class_<RepeatNode, DatasetNode, std::shared_ptr<RepeatNode>>(*m, "RepeatNode",
|
||||||
"to create a RepeatNode")
|
"to create a RepeatNode")
|
||||||
.def(py::init([](std::shared_ptr<DatasetNode> input, int32_t count) {
|
.def(py::init([](const std::shared_ptr<DatasetNode> &input, int32_t count) {
|
||||||
auto repeat = std::make_shared<RepeatNode>(input, count);
|
auto repeat = std::make_shared<RepeatNode>(input, count);
|
||||||
THROW_IF_ERROR(repeat->ValidateParams());
|
THROW_IF_ERROR(repeat->ValidateParams());
|
||||||
return repeat;
|
return repeat;
|
||||||
|
@ -236,7 +238,8 @@ PYBIND_REGISTER(RepeatNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(ShuffleNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(ShuffleNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<ShuffleNode, DatasetNode, std::shared_ptr<ShuffleNode>>(*m, "ShuffleNode",
|
(void)py::class_<ShuffleNode, DatasetNode, std::shared_ptr<ShuffleNode>>(*m, "ShuffleNode",
|
||||||
"to create a ShuffleNode")
|
"to create a ShuffleNode")
|
||||||
.def(py::init([](std::shared_ptr<DatasetNode> self, int32_t shuffle_size, bool reset_every_epoch) {
|
.def(py::init(
|
||||||
|
[](const std::shared_ptr<DatasetNode> &self, int32_t shuffle_size, bool reset_every_epoch) {
|
||||||
auto shuffle = std::make_shared<ShuffleNode>(self, shuffle_size, reset_every_epoch);
|
auto shuffle = std::make_shared<ShuffleNode>(self, shuffle_size, reset_every_epoch);
|
||||||
THROW_IF_ERROR(shuffle->ValidateParams());
|
THROW_IF_ERROR(shuffle->ValidateParams());
|
||||||
return shuffle;
|
return shuffle;
|
||||||
|
@ -246,7 +249,7 @@ PYBIND_REGISTER(ShuffleNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(SkipNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(SkipNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<SkipNode, DatasetNode, std::shared_ptr<SkipNode>>(*m, "SkipNode",
|
(void)py::class_<SkipNode, DatasetNode, std::shared_ptr<SkipNode>>(*m, "SkipNode",
|
||||||
"to create a SkipNode")
|
"to create a SkipNode")
|
||||||
.def(py::init([](std::shared_ptr<DatasetNode> self, int32_t count) {
|
.def(py::init([](const std::shared_ptr<DatasetNode> &self, int32_t count) {
|
||||||
auto skip = std::make_shared<SkipNode>(self, count);
|
auto skip = std::make_shared<SkipNode>(self, count);
|
||||||
THROW_IF_ERROR(skip->ValidateParams());
|
THROW_IF_ERROR(skip->ValidateParams());
|
||||||
return skip;
|
return skip;
|
||||||
|
@ -256,8 +259,8 @@ PYBIND_REGISTER(SkipNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(SyncWaitNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(SyncWaitNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<SyncWaitNode, DatasetNode, std::shared_ptr<SyncWaitNode>>(*m, "SyncWaitNode",
|
(void)py::class_<SyncWaitNode, DatasetNode, std::shared_ptr<SyncWaitNode>>(*m, "SyncWaitNode",
|
||||||
"to create a SyncWaitNode")
|
"to create a SyncWaitNode")
|
||||||
.def(
|
.def(py::init([](const std::shared_ptr<DatasetNode> &self, const std::string &condition_name,
|
||||||
py::init([](std::shared_ptr<DatasetNode> self, std::string condition_name, py::object callback) {
|
py::object callback) {
|
||||||
py::function callback_func =
|
py::function callback_func =
|
||||||
py::isinstance<py::function>(callback) ? callback.cast<py::function>() : py::function();
|
py::isinstance<py::function>(callback) ? callback.cast<py::function>() : py::function();
|
||||||
auto sync_wait = std::make_shared<SyncWaitNode>(self, condition_name, callback);
|
auto sync_wait = std::make_shared<SyncWaitNode>(self, condition_name, callback);
|
||||||
|
@ -269,7 +272,7 @@ PYBIND_REGISTER(SyncWaitNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(TakeNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(TakeNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<TakeNode, DatasetNode, std::shared_ptr<TakeNode>>(*m, "TakeNode",
|
(void)py::class_<TakeNode, DatasetNode, std::shared_ptr<TakeNode>>(*m, "TakeNode",
|
||||||
"to create a TakeNode")
|
"to create a TakeNode")
|
||||||
.def(py::init([](std::shared_ptr<DatasetNode> self, int32_t count) {
|
.def(py::init([](const std::shared_ptr<DatasetNode> &self, int32_t count) {
|
||||||
auto take = std::make_shared<TakeNode>(self, count);
|
auto take = std::make_shared<TakeNode>(self, count);
|
||||||
THROW_IF_ERROR(take->ValidateParams());
|
THROW_IF_ERROR(take->ValidateParams());
|
||||||
return take;
|
return take;
|
||||||
|
@ -279,9 +282,9 @@ PYBIND_REGISTER(TakeNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(TransferNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(TransferNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<TransferNode, DatasetNode, std::shared_ptr<TransferNode>>(*m, "TransferNode",
|
(void)py::class_<TransferNode, DatasetNode, std::shared_ptr<TransferNode>>(*m, "TransferNode",
|
||||||
"to create a TransferNode")
|
"to create a TransferNode")
|
||||||
.def(py::init([](std::shared_ptr<DatasetNode> self, std::string queue_name, std::string device_type,
|
.def(py::init([](const std::shared_ptr<DatasetNode> &self, const std::string &queue_name,
|
||||||
int32_t device_id, bool send_epoch_end, int32_t total_batch,
|
const std::string &device_type, int32_t device_id, bool send_epoch_end,
|
||||||
bool create_data_info_queue) {
|
int32_t total_batch, bool create_data_info_queue) {
|
||||||
auto transfer = std::make_shared<TransferNode>(
|
auto transfer = std::make_shared<TransferNode>(
|
||||||
self, queue_name, device_type, device_id, send_epoch_end, total_batch, create_data_info_queue);
|
self, queue_name, device_type, device_id, send_epoch_end, total_batch, create_data_info_queue);
|
||||||
THROW_IF_ERROR(transfer->ValidateParams());
|
THROW_IF_ERROR(transfer->ValidateParams());
|
||||||
|
@ -291,7 +294,7 @@ PYBIND_REGISTER(TransferNode, 2, ([](const py::module *m) {
|
||||||
|
|
||||||
PYBIND_REGISTER(ZipNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(ZipNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<ZipNode, DatasetNode, std::shared_ptr<ZipNode>>(*m, "ZipNode", "to create a ZipNode")
|
(void)py::class_<ZipNode, DatasetNode, std::shared_ptr<ZipNode>>(*m, "ZipNode", "to create a ZipNode")
|
||||||
.def(py::init([](std::vector<std::shared_ptr<DatasetNode>> datasets) {
|
.def(py::init([](const std::vector<std::shared_ptr<DatasetNode>> &datasets) {
|
||||||
auto zip = std::make_shared<ZipNode>(datasets);
|
auto zip = std::make_shared<ZipNode>(datasets);
|
||||||
THROW_IF_ERROR(zip->ValidateParams());
|
THROW_IF_ERROR(zip->ValidateParams());
|
||||||
return zip;
|
return zip;
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -17,11 +17,10 @@
|
||||||
|
|
||||||
#include "minddata/dataset/api/python/pybind_conversion.h"
|
#include "minddata/dataset/api/python/pybind_conversion.h"
|
||||||
#include "minddata/dataset/api/python/pybind_register.h"
|
#include "minddata/dataset/api/python/pybind_register.h"
|
||||||
#include "minddata/dataset/include/dataset/constants.h"
|
|
||||||
#include "minddata/dataset/include/dataset/datasets.h"
|
|
||||||
|
|
||||||
#include "minddata/dataset/core/config_manager.h"
|
#include "minddata/dataset/core/config_manager.h"
|
||||||
#include "minddata/dataset/core/data_type.h"
|
#include "minddata/dataset/core/data_type.h"
|
||||||
|
#include "minddata/dataset/include/dataset/constants.h"
|
||||||
|
#include "minddata/dataset/include/dataset/datasets.h"
|
||||||
#include "minddata/dataset/util/path.h"
|
#include "minddata/dataset/util/path.h"
|
||||||
|
|
||||||
// IR leaf nodes
|
// IR leaf nodes
|
||||||
|
@ -29,8 +28,8 @@
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/amazon_review_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/amazon_review_node.h"
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/caltech256_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/caltech256_node.h"
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h"
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h"
|
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
|
||||||
|
#include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h"
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/cityscapes_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/cityscapes_node.h"
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
|
||||||
|
@ -81,14 +80,13 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
|
||||||
// PYBIND FOR LEAF NODES
|
// PYBIND FOR LEAF NODES
|
||||||
// (In alphabetical order)
|
// (In alphabetical order)
|
||||||
PYBIND_REGISTER(AGNewsNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(AGNewsNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<AGNewsNode, DatasetNode, std::shared_ptr<AGNewsNode>>(*m, "AGNewsNode",
|
(void)py::class_<AGNewsNode, DatasetNode, std::shared_ptr<AGNewsNode>>(*m, "AGNewsNode",
|
||||||
"to create an AGNewsNode")
|
"to create an AGNewsNode")
|
||||||
.def(py::init([](std::string dataset_dir, std::string usage, int64_t num_samples, int32_t shuffle,
|
.def(py::init([](const std::string &dataset_dir, const std::string &usage, int64_t num_samples,
|
||||||
int32_t num_shards, int32_t shard_id) {
|
int32_t shuffle, int32_t num_shards, int32_t shard_id) {
|
||||||
auto ag_news = std::make_shared<AGNewsNode>(dataset_dir, num_samples, toShuffleMode(shuffle),
|
auto ag_news = std::make_shared<AGNewsNode>(dataset_dir, num_samples, toShuffleMode(shuffle),
|
||||||
usage, num_shards, shard_id, nullptr);
|
usage, num_shards, shard_id, nullptr);
|
||||||
THROW_IF_ERROR(ag_news->ValidateParams());
|
THROW_IF_ERROR(ag_news->ValidateParams());
|
||||||
|
@ -99,8 +97,8 @@ PYBIND_REGISTER(AGNewsNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(AmazonReviewNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(AmazonReviewNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<AmazonReviewNode, DatasetNode, std::shared_ptr<AmazonReviewNode>>(
|
(void)py::class_<AmazonReviewNode, DatasetNode, std::shared_ptr<AmazonReviewNode>>(
|
||||||
*m, "AmazonReviewNode", "to create an AmazonReviewNode")
|
*m, "AmazonReviewNode", "to create an AmazonReviewNode")
|
||||||
.def(py::init([](std::string dataset_dir, std::string usage, int64_t num_samples, int32_t shuffle,
|
.def(py::init([](const std::string &dataset_dir, const std::string &usage, int64_t num_samples,
|
||||||
int32_t num_shards, int32_t shard_id) {
|
int32_t shuffle, int32_t num_shards, int32_t shard_id) {
|
||||||
std::shared_ptr<AmazonReviewNode> amazon_review = std::make_shared<AmazonReviewNode>(
|
std::shared_ptr<AmazonReviewNode> amazon_review = std::make_shared<AmazonReviewNode>(
|
||||||
dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
|
dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
|
||||||
THROW_IF_ERROR(amazon_review->ValidateParams());
|
THROW_IF_ERROR(amazon_review->ValidateParams());
|
||||||
|
@ -111,7 +109,7 @@ PYBIND_REGISTER(AmazonReviewNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(Caltech256Node, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(Caltech256Node, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<Caltech256Node, DatasetNode, std::shared_ptr<Caltech256Node>>(
|
(void)py::class_<Caltech256Node, DatasetNode, std::shared_ptr<Caltech256Node>>(
|
||||||
*m, "Caltech256Node", "to create a Caltech256Node")
|
*m, "Caltech256Node", "to create a Caltech256Node")
|
||||||
.def(py::init([](std::string dataset_dir, bool decode, py::handle sampler) {
|
.def(py::init([](const std::string &dataset_dir, bool decode, const py::handle &sampler) {
|
||||||
auto caltech256 =
|
auto caltech256 =
|
||||||
std::make_shared<Caltech256Node>(dataset_dir, decode, toSamplerObj(sampler), nullptr);
|
std::make_shared<Caltech256Node>(dataset_dir, decode, toSamplerObj(sampler), nullptr);
|
||||||
THROW_IF_ERROR(caltech256->ValidateParams());
|
THROW_IF_ERROR(caltech256->ValidateParams());
|
||||||
|
@ -122,8 +120,8 @@ PYBIND_REGISTER(Caltech256Node, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(CelebANode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(CelebANode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<CelebANode, DatasetNode, std::shared_ptr<CelebANode>>(*m, "CelebANode",
|
(void)py::class_<CelebANode, DatasetNode, std::shared_ptr<CelebANode>>(*m, "CelebANode",
|
||||||
"to create a CelebANode")
|
"to create a CelebANode")
|
||||||
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler, bool decode,
|
.def(py::init([](const std::string &dataset_dir, const std::string &usage,
|
||||||
py::list extensions) {
|
const py::handle &sampler, bool decode, const py::list &extensions) {
|
||||||
auto celebA = std::make_shared<CelebANode>(dataset_dir, usage, toSamplerObj(sampler), decode,
|
auto celebA = std::make_shared<CelebANode>(dataset_dir, usage, toSamplerObj(sampler), decode,
|
||||||
toStringSet(extensions), nullptr);
|
toStringSet(extensions), nullptr);
|
||||||
THROW_IF_ERROR(celebA->ValidateParams());
|
THROW_IF_ERROR(celebA->ValidateParams());
|
||||||
|
@ -134,7 +132,8 @@ PYBIND_REGISTER(CelebANode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(Cifar10Node, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(Cifar10Node, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<Cifar10Node, DatasetNode, std::shared_ptr<Cifar10Node>>(*m, "Cifar10Node",
|
(void)py::class_<Cifar10Node, DatasetNode, std::shared_ptr<Cifar10Node>>(*m, "Cifar10Node",
|
||||||
"to create a Cifar10Node")
|
"to create a Cifar10Node")
|
||||||
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) {
|
.def(py::init([](const std::string &dataset_dir, const std::string &usage,
|
||||||
|
const py::handle &sampler) {
|
||||||
auto cifar10 = std::make_shared<Cifar10Node>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
|
auto cifar10 = std::make_shared<Cifar10Node>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
|
||||||
THROW_IF_ERROR(cifar10->ValidateParams());
|
THROW_IF_ERROR(cifar10->ValidateParams());
|
||||||
return cifar10;
|
return cifar10;
|
||||||
|
@ -144,7 +143,8 @@ PYBIND_REGISTER(Cifar10Node, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(Cifar100Node, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(Cifar100Node, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<Cifar100Node, DatasetNode, std::shared_ptr<Cifar100Node>>(*m, "Cifar100Node",
|
(void)py::class_<Cifar100Node, DatasetNode, std::shared_ptr<Cifar100Node>>(*m, "Cifar100Node",
|
||||||
"to create a Cifar100Node")
|
"to create a Cifar100Node")
|
||||||
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) {
|
.def(
|
||||||
|
py::init([](const std::string &dataset_dir, const std::string &usage, const py::handle &sampler) {
|
||||||
auto cifar100 =
|
auto cifar100 =
|
||||||
std::make_shared<Cifar100Node>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
|
std::make_shared<Cifar100Node>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
|
||||||
THROW_IF_ERROR(cifar100->ValidateParams());
|
THROW_IF_ERROR(cifar100->ValidateParams());
|
||||||
|
@ -155,8 +155,9 @@ PYBIND_REGISTER(Cifar100Node, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(CityscapesNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(CityscapesNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<CityscapesNode, DatasetNode, std::shared_ptr<CityscapesNode>>(
|
(void)py::class_<CityscapesNode, DatasetNode, std::shared_ptr<CityscapesNode>>(
|
||||||
*m, "CityscapesNode", "to create a CityscapesNode")
|
*m, "CityscapesNode", "to create a CityscapesNode")
|
||||||
.def(py::init([](std::string dataset_dir, std::string usage, std::string quality_mode,
|
.def(py::init([](const std::string &dataset_dir, const std::string &usage,
|
||||||
std::string task, bool decode, const py::handle &sampler) {
|
const std::string &quality_mode, const std::string &task, bool decode,
|
||||||
|
const py::handle &sampler) {
|
||||||
auto cityscapes = std::make_shared<CityscapesNode>(dataset_dir, usage, quality_mode, task, decode,
|
auto cityscapes = std::make_shared<CityscapesNode>(dataset_dir, usage, quality_mode, task, decode,
|
||||||
toSamplerObj(sampler), nullptr);
|
toSamplerObj(sampler), nullptr);
|
||||||
THROW_IF_ERROR(cityscapes->ValidateParams());
|
THROW_IF_ERROR(cityscapes->ValidateParams());
|
||||||
|
@ -167,8 +168,8 @@ PYBIND_REGISTER(CityscapesNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(CLUENode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(CLUENode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<CLUENode, DatasetNode, std::shared_ptr<CLUENode>>(*m, "CLUENode",
|
(void)py::class_<CLUENode, DatasetNode, std::shared_ptr<CLUENode>>(*m, "CLUENode",
|
||||||
"to create a CLUENode")
|
"to create a CLUENode")
|
||||||
.def(py::init([](py::list files, std::string task, std::string usage, int64_t num_samples,
|
.def(py::init([](const py::list &files, const std::string &task, const std::string &usage,
|
||||||
int32_t shuffle, int32_t num_shards, int32_t shard_id) {
|
int64_t num_samples, int32_t shuffle, int32_t num_shards, int32_t shard_id) {
|
||||||
std::shared_ptr<CLUENode> clue_node =
|
std::shared_ptr<CLUENode> clue_node =
|
||||||
std::make_shared<dataset::CLUENode>(toStringVector(files), task, usage, num_samples,
|
std::make_shared<dataset::CLUENode>(toStringVector(files), task, usage, num_samples,
|
||||||
toShuffleMode(shuffle), num_shards, shard_id, nullptr);
|
toShuffleMode(shuffle), num_shards, shard_id, nullptr);
|
||||||
|
@ -180,8 +181,9 @@ PYBIND_REGISTER(CLUENode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(CocoNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(CocoNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<CocoNode, DatasetNode, std::shared_ptr<CocoNode>>(*m, "CocoNode",
|
(void)py::class_<CocoNode, DatasetNode, std::shared_ptr<CocoNode>>(*m, "CocoNode",
|
||||||
"to create a CocoNode")
|
"to create a CocoNode")
|
||||||
.def(py::init([](std::string dataset_dir, std::string annotation_file, std::string task,
|
.def(py::init([](const std::string &dataset_dir, const std::string &annotation_file,
|
||||||
bool decode, const py::handle &sampler, bool extra_metadata) {
|
const std::string &task, bool decode, const py::handle &sampler,
|
||||||
|
bool extra_metadata) {
|
||||||
std::shared_ptr<CocoNode> coco = std::make_shared<CocoNode>(
|
std::shared_ptr<CocoNode> coco = std::make_shared<CocoNode>(
|
||||||
dataset_dir, annotation_file, task, decode, toSamplerObj(sampler), nullptr, extra_metadata);
|
dataset_dir, annotation_file, task, decode, toSamplerObj(sampler), nullptr, extra_metadata);
|
||||||
THROW_IF_ERROR(coco->ValidateParams());
|
THROW_IF_ERROR(coco->ValidateParams());
|
||||||
|
@ -192,8 +194,8 @@ PYBIND_REGISTER(CocoNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(CoNLL2000Node, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(CoNLL2000Node, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<CoNLL2000Node, DatasetNode, std::shared_ptr<CoNLL2000Node>>(
|
(void)py::class_<CoNLL2000Node, DatasetNode, std::shared_ptr<CoNLL2000Node>>(
|
||||||
*m, "CoNLL2000Node", "to create a CoNLL2000Node")
|
*m, "CoNLL2000Node", "to create a CoNLL2000Node")
|
||||||
.def(py::init([](std::string dataset_dir, std::string usage, int64_t num_samples, int32_t shuffle,
|
.def(py::init([](const std::string &dataset_dir, const std::string &usage, int64_t num_samples,
|
||||||
int32_t num_shards, int32_t shard_id) {
|
int32_t shuffle, int32_t num_shards, int32_t shard_id) {
|
||||||
std::shared_ptr<CoNLL2000Node> conll2000 = std::make_shared<CoNLL2000Node>(
|
std::shared_ptr<CoNLL2000Node> conll2000 = std::make_shared<CoNLL2000Node>(
|
||||||
dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
|
dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
|
||||||
THROW_IF_ERROR(conll2000->ValidateParams());
|
THROW_IF_ERROR(conll2000->ValidateParams());
|
||||||
|
@ -203,9 +205,9 @@ PYBIND_REGISTER(CoNLL2000Node, 2, ([](const py::module *m) {
|
||||||
|
|
||||||
PYBIND_REGISTER(CSVNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(CSVNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<CSVNode, DatasetNode, std::shared_ptr<CSVNode>>(*m, "CSVNode", "to create a CSVNode")
|
(void)py::class_<CSVNode, DatasetNode, std::shared_ptr<CSVNode>>(*m, "CSVNode", "to create a CSVNode")
|
||||||
.def(py::init([](std::vector<std::string> csv_files, char field_delim, py::list column_defaults,
|
.def(py::init([](const std::vector<std::string> &csv_files, char field_delim,
|
||||||
std::vector<std::string> column_names, int64_t num_samples, int32_t shuffle,
|
const py::list &column_defaults, const std::vector<std::string> &column_names,
|
||||||
int32_t num_shards, int32_t shard_id) {
|
int64_t num_samples, int32_t shuffle, int32_t num_shards, int32_t shard_id) {
|
||||||
auto csv =
|
auto csv =
|
||||||
std::make_shared<CSVNode>(csv_files, field_delim, toCSVBase(column_defaults), column_names,
|
std::make_shared<CSVNode>(csv_files, field_delim, toCSVBase(column_defaults), column_names,
|
||||||
num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
|
num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
|
||||||
|
@ -217,8 +219,8 @@ PYBIND_REGISTER(CSVNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(DBpediaNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(DBpediaNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<DBpediaNode, DatasetNode, std::shared_ptr<DBpediaNode>>(*m, "DBpediaNode",
|
(void)py::class_<DBpediaNode, DatasetNode, std::shared_ptr<DBpediaNode>>(*m, "DBpediaNode",
|
||||||
"to create a DBpediaNode")
|
"to create a DBpediaNode")
|
||||||
.def(py::init([](std::string dataset_dir, std::string usage, int64_t num_samples, int32_t shuffle,
|
.def(py::init([](const std::string &dataset_dir, const std::string &usage, int64_t num_samples,
|
||||||
int32_t num_shards, int32_t shard_id) {
|
int32_t shuffle, int32_t num_shards, int32_t shard_id) {
|
||||||
auto dbpedia = std::make_shared<DBpediaNode>(
|
auto dbpedia = std::make_shared<DBpediaNode>(
|
||||||
dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
|
dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
|
||||||
THROW_IF_ERROR(dbpedia->ValidateParams());
|
THROW_IF_ERROR(dbpedia->ValidateParams());
|
||||||
|
@ -229,8 +231,9 @@ PYBIND_REGISTER(DBpediaNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(DIV2KNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(DIV2KNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<DIV2KNode, DatasetNode, std::shared_ptr<DIV2KNode>>(*m, "DIV2KNode",
|
(void)py::class_<DIV2KNode, DatasetNode, std::shared_ptr<DIV2KNode>>(*m, "DIV2KNode",
|
||||||
"to create a DIV2KNode")
|
"to create a DIV2KNode")
|
||||||
.def(py::init([](std::string dataset_dir, std::string usage, std::string downgrade, int32_t scale,
|
.def(
|
||||||
bool decode, py::handle sampler) {
|
py::init([](const std::string &dataset_dir, const std::string &usage,
|
||||||
|
const std::string &downgrade, int32_t scale, bool decode, const py::handle &sampler) {
|
||||||
auto div2k = std::make_shared<DIV2KNode>(dataset_dir, usage, downgrade, scale, decode,
|
auto div2k = std::make_shared<DIV2KNode>(dataset_dir, usage, downgrade, scale, decode,
|
||||||
toSamplerObj(sampler), nullptr);
|
toSamplerObj(sampler), nullptr);
|
||||||
THROW_IF_ERROR(div2k->ValidateParams());
|
THROW_IF_ERROR(div2k->ValidateParams());
|
||||||
|
@ -241,7 +244,8 @@ PYBIND_REGISTER(DIV2KNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(EMnistNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(EMnistNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<EMnistNode, DatasetNode, std::shared_ptr<EMnistNode>>(*m, "EMnistNode",
|
(void)py::class_<EMnistNode, DatasetNode, std::shared_ptr<EMnistNode>>(*m, "EMnistNode",
|
||||||
"to create an EMnistNode")
|
"to create an EMnistNode")
|
||||||
.def(py::init([](std::string dataset_dir, std::string name, std::string usage, py::handle sampler) {
|
.def(py::init([](const std::string &dataset_dir, const std::string &name, const std::string &usage,
|
||||||
|
const py::handle &sampler) {
|
||||||
auto emnist =
|
auto emnist =
|
||||||
std::make_shared<EMnistNode>(dataset_dir, name, usage, toSamplerObj(sampler), nullptr);
|
std::make_shared<EMnistNode>(dataset_dir, name, usage, toSamplerObj(sampler), nullptr);
|
||||||
THROW_IF_ERROR(emnist->ValidateParams());
|
THROW_IF_ERROR(emnist->ValidateParams());
|
||||||
|
@ -252,8 +256,8 @@ PYBIND_REGISTER(EMnistNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(EnWik9Node, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(EnWik9Node, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<EnWik9Node, DatasetNode, std::shared_ptr<EnWik9Node>>(*m, "EnWik9Node",
|
(void)py::class_<EnWik9Node, DatasetNode, std::shared_ptr<EnWik9Node>>(*m, "EnWik9Node",
|
||||||
"to create an EnWik9Node")
|
"to create an EnWik9Node")
|
||||||
.def(py::init([](std::string dataset_dir, int32_t num_samples, int32_t shuffle, int32_t num_shards,
|
.def(py::init([](const std::string &dataset_dir, int32_t num_samples, int32_t shuffle,
|
||||||
int32_t shard_id) {
|
int32_t num_shards, int32_t shard_id) {
|
||||||
std::shared_ptr<EnWik9Node> en_wik9 = std::make_shared<EnWik9Node>(
|
std::shared_ptr<EnWik9Node> en_wik9 = std::make_shared<EnWik9Node>(
|
||||||
dataset_dir, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
|
dataset_dir, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
|
||||||
THROW_IF_ERROR(en_wik9->ValidateParams());
|
THROW_IF_ERROR(en_wik9->ValidateParams());
|
||||||
|
@ -264,8 +268,8 @@ PYBIND_REGISTER(EnWik9Node, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(FakeImageNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(FakeImageNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<FakeImageNode, DatasetNode, std::shared_ptr<FakeImageNode>>(
|
(void)py::class_<FakeImageNode, DatasetNode, std::shared_ptr<FakeImageNode>>(
|
||||||
*m, "FakeImageNode", "to create a FakeImageNode")
|
*m, "FakeImageNode", "to create a FakeImageNode")
|
||||||
.def(py::init([](int32_t num_images, const std::vector<int32_t> image_size, int32_t num_classes,
|
.def(py::init([](int32_t num_images, const std::vector<int32_t> &image_size, int32_t num_classes,
|
||||||
int32_t base_seed, py::handle sampler) {
|
int32_t base_seed, const py::handle &sampler) {
|
||||||
auto fake_image = std::make_shared<FakeImageNode>(num_images, image_size, num_classes, base_seed,
|
auto fake_image = std::make_shared<FakeImageNode>(num_images, image_size, num_classes, base_seed,
|
||||||
toSamplerObj(sampler), nullptr);
|
toSamplerObj(sampler), nullptr);
|
||||||
THROW_IF_ERROR(fake_image->ValidateParams());
|
THROW_IF_ERROR(fake_image->ValidateParams());
|
||||||
|
@ -276,7 +280,8 @@ PYBIND_REGISTER(FakeImageNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(FashionMnistNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(FashionMnistNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<FashionMnistNode, DatasetNode, std::shared_ptr<FashionMnistNode>>(
|
(void)py::class_<FashionMnistNode, DatasetNode, std::shared_ptr<FashionMnistNode>>(
|
||||||
*m, "FashionMnistNode", "to create a FashionMnistNode")
|
*m, "FashionMnistNode", "to create a FashionMnistNode")
|
||||||
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) {
|
.def(
|
||||||
|
py::init([](const std::string &dataset_dir, const std::string &usage, const py::handle &sampler) {
|
||||||
auto fashion_mnist =
|
auto fashion_mnist =
|
||||||
std::make_shared<FashionMnistNode>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
|
std::make_shared<FashionMnistNode>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
|
||||||
THROW_IF_ERROR(fashion_mnist->ValidateParams());
|
THROW_IF_ERROR(fashion_mnist->ValidateParams());
|
||||||
|
@ -284,12 +289,13 @@ PYBIND_REGISTER(FashionMnistNode, 2, ([](const py::module *m) {
|
||||||
}));
|
}));
|
||||||
}));
|
}));
|
||||||
|
|
||||||
PYBIND_REGISTER(
|
PYBIND_REGISTER(FlickrNode, 2, ([](const py::module *m) {
|
||||||
FlickrNode, 2, ([](const py::module *m) {
|
(void)py::class_<FlickrNode, DatasetNode, std::shared_ptr<FlickrNode>>(*m, "FlickrNode",
|
||||||
(void)py::class_<FlickrNode, DatasetNode, std::shared_ptr<FlickrNode>>(*m, "FlickrNode", "to create a FlickrNode")
|
"to create a FlickrNode")
|
||||||
.def(py::init([](std::string dataset_dir, std::string annotation_file, bool decode, const py::handle &sampler) {
|
.def(py::init([](const std::string &dataset_dir, const std::string &annotation_file, bool decode,
|
||||||
auto flickr =
|
const py::handle &sampler) {
|
||||||
std::make_shared<FlickrNode>(dataset_dir, annotation_file, decode, toSamplerObj(sampler), nullptr);
|
auto flickr = std::make_shared<FlickrNode>(dataset_dir, annotation_file, decode,
|
||||||
|
toSamplerObj(sampler), nullptr);
|
||||||
THROW_IF_ERROR(flickr->ValidateParams());
|
THROW_IF_ERROR(flickr->ValidateParams());
|
||||||
return flickr;
|
return flickr;
|
||||||
}));
|
}));
|
||||||
|
@ -298,17 +304,18 @@ PYBIND_REGISTER(
|
||||||
PYBIND_REGISTER(GeneratorNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(GeneratorNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<GeneratorNode, DatasetNode, std::shared_ptr<GeneratorNode>>(
|
(void)py::class_<GeneratorNode, DatasetNode, std::shared_ptr<GeneratorNode>>(
|
||||||
*m, "GeneratorNode", "to create a GeneratorNode")
|
*m, "GeneratorNode", "to create a GeneratorNode")
|
||||||
.def(py::init([](py::function generator_function, const std::vector<std::string> &column_names,
|
.def(
|
||||||
const std::vector<DataType> &column_types, int64_t dataset_len, py::handle sampler,
|
py::init([](const py::function &generator_function, const std::vector<std::string> &column_names,
|
||||||
uint32_t num_parallel_workers) {
|
const std::vector<DataType> &column_types, int64_t dataset_len,
|
||||||
|
const py::handle &sampler, uint32_t num_parallel_workers) {
|
||||||
auto gen =
|
auto gen =
|
||||||
std::make_shared<GeneratorNode>(generator_function, column_names, column_types, dataset_len,
|
std::make_shared<GeneratorNode>(generator_function, column_names, column_types, dataset_len,
|
||||||
toSamplerObj(sampler), num_parallel_workers);
|
toSamplerObj(sampler), num_parallel_workers);
|
||||||
THROW_IF_ERROR(gen->ValidateParams());
|
THROW_IF_ERROR(gen->ValidateParams());
|
||||||
return gen;
|
return gen;
|
||||||
}))
|
}))
|
||||||
.def(py::init([](py::function generator_function, const std::shared_ptr<SchemaObj> schema,
|
.def(py::init([](const py::function &generator_function, const std::shared_ptr<SchemaObj> &schema,
|
||||||
int64_t dataset_len, py::handle sampler, uint32_t num_parallel_workers) {
|
int64_t dataset_len, const py::handle &sampler, uint32_t num_parallel_workers) {
|
||||||
auto gen = std::make_shared<GeneratorNode>(generator_function, schema, dataset_len,
|
auto gen = std::make_shared<GeneratorNode>(generator_function, schema, dataset_len,
|
||||||
toSamplerObj(sampler), num_parallel_workers);
|
toSamplerObj(sampler), num_parallel_workers);
|
||||||
THROW_IF_ERROR(gen->ValidateParams());
|
THROW_IF_ERROR(gen->ValidateParams());
|
||||||
|
@ -319,8 +326,8 @@ PYBIND_REGISTER(GeneratorNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(ImageFolderNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(ImageFolderNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<ImageFolderNode, DatasetNode, std::shared_ptr<ImageFolderNode>>(
|
(void)py::class_<ImageFolderNode, DatasetNode, std::shared_ptr<ImageFolderNode>>(
|
||||||
*m, "ImageFolderNode", "to create an ImageFolderNode")
|
*m, "ImageFolderNode", "to create an ImageFolderNode")
|
||||||
.def(py::init([](std::string dataset_dir, bool decode, py::handle sampler, py::list extensions,
|
.def(py::init([](const std::string &dataset_dir, bool decode, const py::handle &sampler,
|
||||||
py::dict class_indexing) {
|
const py::list &extensions, const py::dict &class_indexing) {
|
||||||
// Don't update recursive to true
|
// Don't update recursive to true
|
||||||
bool recursive = false; // Will be removed in future PR
|
bool recursive = false; // Will be removed in future PR
|
||||||
auto imagefolder = std::make_shared<ImageFolderNode>(dataset_dir, decode, toSamplerObj(sampler),
|
auto imagefolder = std::make_shared<ImageFolderNode>(dataset_dir, decode, toSamplerObj(sampler),
|
||||||
|
@ -334,7 +341,8 @@ PYBIND_REGISTER(ImageFolderNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(IMDBNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(IMDBNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<IMDBNode, DatasetNode, std::shared_ptr<IMDBNode>>(*m, "IMDBNode",
|
(void)py::class_<IMDBNode, DatasetNode, std::shared_ptr<IMDBNode>>(*m, "IMDBNode",
|
||||||
"to create an IMDBNode")
|
"to create an IMDBNode")
|
||||||
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) {
|
.def(
|
||||||
|
py::init([](const std::string &dataset_dir, const std::string &usage, const py::handle &sampler) {
|
||||||
auto imdb = std::make_shared<IMDBNode>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
|
auto imdb = std::make_shared<IMDBNode>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
|
||||||
THROW_IF_ERROR(imdb->ValidateParams());
|
THROW_IF_ERROR(imdb->ValidateParams());
|
||||||
return imdb;
|
return imdb;
|
||||||
|
@ -344,8 +352,9 @@ PYBIND_REGISTER(IMDBNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(IWSLT2016Node, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(IWSLT2016Node, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<IWSLT2016Node, DatasetNode, std::shared_ptr<IWSLT2016Node>>(
|
(void)py::class_<IWSLT2016Node, DatasetNode, std::shared_ptr<IWSLT2016Node>>(
|
||||||
*m, "IWSLT2016Node", "to create an IWSLT2016Node")
|
*m, "IWSLT2016Node", "to create an IWSLT2016Node")
|
||||||
.def(py::init([](std::string dataset_dir, std::string usage, std::vector<std::string> language_pair,
|
.def(py::init([](const std::string &dataset_dir, const std::string &usage,
|
||||||
std::string valid_set, std::string test_set, int64_t num_samples, int32_t shuffle,
|
const std::vector<std::string> &language_pair, const std::string &valid_set,
|
||||||
|
const std::string &test_set, int64_t num_samples, int32_t shuffle,
|
||||||
int32_t num_shards, int32_t shard_id) {
|
int32_t num_shards, int32_t shard_id) {
|
||||||
std::shared_ptr<IWSLT2016Node> iwslt2016 = std::make_shared<IWSLT2016Node>(
|
std::shared_ptr<IWSLT2016Node> iwslt2016 = std::make_shared<IWSLT2016Node>(
|
||||||
dataset_dir, usage, language_pair, valid_set, test_set, num_samples, toShuffleMode(shuffle),
|
dataset_dir, usage, language_pair, valid_set, test_set, num_samples, toShuffleMode(shuffle),
|
||||||
|
@ -358,8 +367,9 @@ PYBIND_REGISTER(IWSLT2016Node, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(IWSLT2017Node, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(IWSLT2017Node, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<IWSLT2017Node, DatasetNode, std::shared_ptr<IWSLT2017Node>>(
|
(void)py::class_<IWSLT2017Node, DatasetNode, std::shared_ptr<IWSLT2017Node>>(
|
||||||
*m, "IWSLT2017Node", "to create an IWSLT2017Node")
|
*m, "IWSLT2017Node", "to create an IWSLT2017Node")
|
||||||
.def(py::init([](std::string dataset_dir, std::string usage, std::vector<std::string> language_pair,
|
.def(py::init([](const std::string &dataset_dir, const std::string &usage,
|
||||||
int64_t num_samples, int32_t shuffle, int32_t num_shards, int32_t shard_id) {
|
const std::vector<std::string> &language_pair, int64_t num_samples,
|
||||||
|
int32_t shuffle, int32_t num_shards, int32_t shard_id) {
|
||||||
std::shared_ptr<IWSLT2017Node> iwslt2017 =
|
std::shared_ptr<IWSLT2017Node> iwslt2017 =
|
||||||
std::make_shared<IWSLT2017Node>(dataset_dir, usage, language_pair, num_samples,
|
std::make_shared<IWSLT2017Node>(dataset_dir, usage, language_pair, num_samples,
|
||||||
toShuffleMode(shuffle), num_shards, shard_id, nullptr);
|
toShuffleMode(shuffle), num_shards, shard_id, nullptr);
|
||||||
|
@ -371,7 +381,8 @@ PYBIND_REGISTER(IWSLT2017Node, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(KMnistNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(KMnistNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<KMnistNode, DatasetNode, std::shared_ptr<KMnistNode>>(*m, "KMnistNode",
|
(void)py::class_<KMnistNode, DatasetNode, std::shared_ptr<KMnistNode>>(*m, "KMnistNode",
|
||||||
"to create a KMnistNode")
|
"to create a KMnistNode")
|
||||||
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) {
|
.def(
|
||||||
|
py::init([](const std::string &dataset_dir, const std::string &usage, const py::handle &sampler) {
|
||||||
auto kmnist = std::make_shared<KMnistNode>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
|
auto kmnist = std::make_shared<KMnistNode>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
|
||||||
THROW_IF_ERROR(kmnist->ValidateParams());
|
THROW_IF_ERROR(kmnist->ValidateParams());
|
||||||
return kmnist;
|
return kmnist;
|
||||||
|
@ -381,7 +392,7 @@ PYBIND_REGISTER(KMnistNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(LJSpeechNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(LJSpeechNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<LJSpeechNode, DatasetNode, std::shared_ptr<LJSpeechNode>>(*m, "LJSpeechNode",
|
(void)py::class_<LJSpeechNode, DatasetNode, std::shared_ptr<LJSpeechNode>>(*m, "LJSpeechNode",
|
||||||
"to create a LJSpeechNode")
|
"to create a LJSpeechNode")
|
||||||
.def(py::init([](std::string dataset_dir, py::handle sampler) {
|
.def(py::init([](const std::string &dataset_dir, const py::handle &sampler) {
|
||||||
auto lj_speech = std::make_shared<LJSpeechNode>(dataset_dir, toSamplerObj(sampler), nullptr);
|
auto lj_speech = std::make_shared<LJSpeechNode>(dataset_dir, toSamplerObj(sampler), nullptr);
|
||||||
THROW_IF_ERROR(lj_speech->ValidateParams());
|
THROW_IF_ERROR(lj_speech->ValidateParams());
|
||||||
return lj_speech;
|
return lj_speech;
|
||||||
|
@ -391,8 +402,8 @@ PYBIND_REGISTER(LJSpeechNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(ManifestNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(ManifestNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<ManifestNode, DatasetNode, std::shared_ptr<ManifestNode>>(*m, "ManifestNode",
|
(void)py::class_<ManifestNode, DatasetNode, std::shared_ptr<ManifestNode>>(*m, "ManifestNode",
|
||||||
"to create a ManifestNode")
|
"to create a ManifestNode")
|
||||||
.def(py::init([](std::string dataset_file, std::string usage, py::handle sampler,
|
.def(py::init([](const std::string &dataset_file, const std::string &usage,
|
||||||
py::dict class_indexing, bool decode) {
|
const py::handle &sampler, const py::dict &class_indexing, bool decode) {
|
||||||
auto manifest = std::make_shared<ManifestNode>(dataset_file, usage, toSamplerObj(sampler),
|
auto manifest = std::make_shared<ManifestNode>(dataset_file, usage, toSamplerObj(sampler),
|
||||||
toStringMap(class_indexing), decode, nullptr);
|
toStringMap(class_indexing), decode, nullptr);
|
||||||
THROW_IF_ERROR(manifest->ValidateParams());
|
THROW_IF_ERROR(manifest->ValidateParams());
|
||||||
|
@ -400,29 +411,30 @@ PYBIND_REGISTER(ManifestNode, 2, ([](const py::module *m) {
|
||||||
}));
|
}));
|
||||||
}));
|
}));
|
||||||
|
|
||||||
PYBIND_REGISTER(MindDataNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(
|
||||||
|
MindDataNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<MindDataNode, DatasetNode, std::shared_ptr<MindDataNode>>(*m, "MindDataNode",
|
(void)py::class_<MindDataNode, DatasetNode, std::shared_ptr<MindDataNode>>(*m, "MindDataNode",
|
||||||
"to create a MindDataNode")
|
"to create a MindDataNode")
|
||||||
.def(py::init([](std::string dataset_file, py::list columns_list, py::handle sampler,
|
.def(py::init([](const std::string &dataset_file, const py::list &columns_list, const py::handle &sampler,
|
||||||
const py::dict &padded_sample, int64_t num_padded, ShuffleMode shuffle_mode) {
|
const py::dict &padded_sample, int64_t num_padded, ShuffleMode shuffle_mode) {
|
||||||
nlohmann::json padded_sample_json;
|
nlohmann::json padded_sample_json;
|
||||||
std::map<std::string, std::string> sample_bytes;
|
std::map<std::string, std::string> sample_bytes;
|
||||||
THROW_IF_ERROR(ToJson(padded_sample, &padded_sample_json, &sample_bytes));
|
THROW_IF_ERROR(ToJson(padded_sample, &padded_sample_json, &sample_bytes));
|
||||||
auto minddata = std::make_shared<MindDataNode>(dataset_file, toStringVector(columns_list),
|
auto minddata =
|
||||||
toSamplerObj(sampler, true), padded_sample_json,
|
std::make_shared<MindDataNode>(dataset_file, toStringVector(columns_list), toSamplerObj(sampler, true),
|
||||||
num_padded, shuffle_mode, nullptr);
|
padded_sample_json, num_padded, shuffle_mode, nullptr);
|
||||||
minddata->SetSampleBytes(&sample_bytes);
|
minddata->SetSampleBytes(&sample_bytes);
|
||||||
THROW_IF_ERROR(minddata->ValidateParams());
|
THROW_IF_ERROR(minddata->ValidateParams());
|
||||||
return minddata;
|
return minddata;
|
||||||
}))
|
}))
|
||||||
.def(py::init([](py::list dataset_file, py::list columns_list, py::handle sampler,
|
.def(py::init([](const py::list &dataset_file, const py::list &columns_list, const py::handle &sampler,
|
||||||
const py::dict &padded_sample, int64_t num_padded, ShuffleMode shuffle_mode) {
|
const py::dict &padded_sample, int64_t num_padded, ShuffleMode shuffle_mode) {
|
||||||
nlohmann::json padded_sample_json;
|
nlohmann::json padded_sample_json;
|
||||||
std::map<std::string, std::string> sample_bytes;
|
std::map<std::string, std::string> sample_bytes;
|
||||||
THROW_IF_ERROR(ToJson(padded_sample, &padded_sample_json, &sample_bytes));
|
THROW_IF_ERROR(ToJson(padded_sample, &padded_sample_json, &sample_bytes));
|
||||||
auto minddata = std::make_shared<MindDataNode>(
|
auto minddata = std::make_shared<MindDataNode>(toStringVector(dataset_file), toStringVector(columns_list),
|
||||||
toStringVector(dataset_file), toStringVector(columns_list), toSamplerObj(sampler, true),
|
toSamplerObj(sampler, true), padded_sample_json, num_padded,
|
||||||
padded_sample_json, num_padded, shuffle_mode, nullptr);
|
shuffle_mode, nullptr);
|
||||||
minddata->SetSampleBytes(&sample_bytes);
|
minddata->SetSampleBytes(&sample_bytes);
|
||||||
THROW_IF_ERROR(minddata->ValidateParams());
|
THROW_IF_ERROR(minddata->ValidateParams());
|
||||||
return minddata;
|
return minddata;
|
||||||
|
@ -432,7 +444,8 @@ PYBIND_REGISTER(MindDataNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(MnistNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(MnistNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<MnistNode, DatasetNode, std::shared_ptr<MnistNode>>(*m, "MnistNode",
|
(void)py::class_<MnistNode, DatasetNode, std::shared_ptr<MnistNode>>(*m, "MnistNode",
|
||||||
"to create an MnistNode")
|
"to create an MnistNode")
|
||||||
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) {
|
.def(
|
||||||
|
py::init([](const std::string &dataset_dir, const std::string &usage, const py::handle &sampler) {
|
||||||
auto mnist = std::make_shared<MnistNode>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
|
auto mnist = std::make_shared<MnistNode>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
|
||||||
THROW_IF_ERROR(mnist->ValidateParams());
|
THROW_IF_ERROR(mnist->ValidateParams());
|
||||||
return mnist;
|
return mnist;
|
||||||
|
@ -442,8 +455,8 @@ PYBIND_REGISTER(MnistNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(PennTreebankNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(PennTreebankNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<PennTreebankNode, DatasetNode, std::shared_ptr<PennTreebankNode>>(
|
(void)py::class_<PennTreebankNode, DatasetNode, std::shared_ptr<PennTreebankNode>>(
|
||||||
*m, "PennTreebankNode", "to create a PennTreebankNode")
|
*m, "PennTreebankNode", "to create a PennTreebankNode")
|
||||||
.def(py::init([](std::string dataset_dir, std::string usage, int32_t num_samples, int32_t shuffle,
|
.def(py::init([](const std::string &dataset_dir, const std::string &usage, int32_t num_samples,
|
||||||
int32_t num_shards, int32_t shard_id) {
|
int32_t shuffle, int32_t num_shards, int32_t shard_id) {
|
||||||
auto penn_treebank = std::make_shared<PennTreebankNode>(
|
auto penn_treebank = std::make_shared<PennTreebankNode>(
|
||||||
dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
|
dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
|
||||||
THROW_IF_ERROR(penn_treebank->ValidateParams());
|
THROW_IF_ERROR(penn_treebank->ValidateParams());
|
||||||
|
@ -454,7 +467,8 @@ PYBIND_REGISTER(PennTreebankNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(PhotoTourNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(PhotoTourNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<PhotoTourNode, DatasetNode, std::shared_ptr<PhotoTourNode>>(
|
(void)py::class_<PhotoTourNode, DatasetNode, std::shared_ptr<PhotoTourNode>>(
|
||||||
*m, "PhotoTourNode", "to create a PhotoTourNode")
|
*m, "PhotoTourNode", "to create a PhotoTourNode")
|
||||||
.def(py::init([](std::string dataset_dir, std::string name, std::string usage, py::handle sampler) {
|
.def(py::init([](const std::string &dataset_dir, const std::string &name, const std::string &usage,
|
||||||
|
const py::handle &sampler) {
|
||||||
auto photo_tour =
|
auto photo_tour =
|
||||||
std::make_shared<PhotoTourNode>(dataset_dir, name, usage, toSamplerObj(sampler), nullptr);
|
std::make_shared<PhotoTourNode>(dataset_dir, name, usage, toSamplerObj(sampler), nullptr);
|
||||||
THROW_IF_ERROR(photo_tour->ValidateParams());
|
THROW_IF_ERROR(photo_tour->ValidateParams());
|
||||||
|
@ -465,8 +479,8 @@ PYBIND_REGISTER(PhotoTourNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(Places365Node, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(Places365Node, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<Places365Node, DatasetNode, std::shared_ptr<Places365Node>>(
|
(void)py::class_<Places365Node, DatasetNode, std::shared_ptr<Places365Node>>(
|
||||||
*m, "Places365Node", "to create a Places365Node")
|
*m, "Places365Node", "to create a Places365Node")
|
||||||
.def(py::init(
|
.def(py::init([](const std::string &dataset_dir, const std::string &usage, bool small, bool decode,
|
||||||
[](std::string dataset_dir, std::string usage, bool small, bool decode, py::handle sampler) {
|
const py::handle &sampler) {
|
||||||
auto places365 = std::make_shared<Places365Node>(dataset_dir, usage, small, decode,
|
auto places365 = std::make_shared<Places365Node>(dataset_dir, usage, small, decode,
|
||||||
toSamplerObj(sampler), nullptr);
|
toSamplerObj(sampler), nullptr);
|
||||||
THROW_IF_ERROR(places365->ValidateParams());
|
THROW_IF_ERROR(places365->ValidateParams());
|
||||||
|
@ -477,7 +491,8 @@ PYBIND_REGISTER(Places365Node, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(QMnistNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(QMnistNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<QMnistNode, DatasetNode, std::shared_ptr<QMnistNode>>(*m, "QMnistNode",
|
(void)py::class_<QMnistNode, DatasetNode, std::shared_ptr<QMnistNode>>(*m, "QMnistNode",
|
||||||
"to create a QMnistNode")
|
"to create a QMnistNode")
|
||||||
.def(py::init([](std::string dataset_dir, std::string usage, bool compat, py::handle sampler) {
|
.def(py::init([](const std::string &dataset_dir, const std::string &usage, bool compat,
|
||||||
|
const py::handle &sampler) {
|
||||||
auto qmnist =
|
auto qmnist =
|
||||||
std::make_shared<QMnistNode>(dataset_dir, usage, compat, toSamplerObj(sampler), nullptr);
|
std::make_shared<QMnistNode>(dataset_dir, usage, compat, toSamplerObj(sampler), nullptr);
|
||||||
THROW_IF_ERROR(qmnist->ValidateParams());
|
THROW_IF_ERROR(qmnist->ValidateParams());
|
||||||
|
@ -485,18 +500,16 @@ PYBIND_REGISTER(QMnistNode, 2, ([](const py::module *m) {
|
||||||
}));
|
}));
|
||||||
}));
|
}));
|
||||||
|
|
||||||
PYBIND_REGISTER(RandomNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(
|
||||||
(void)py::class_<RandomNode, DatasetNode, std::shared_ptr<RandomNode>>(*m, "RandomNode",
|
RandomNode, 2, ([](const py::module *m) {
|
||||||
"to create a RandomNode")
|
(void)py::class_<RandomNode, DatasetNode, std::shared_ptr<RandomNode>>(*m, "RandomNode", "to create a RandomNode")
|
||||||
.def(py::init([](int32_t total_rows, std::shared_ptr<SchemaObj> schema, py::list columns_list) {
|
.def(py::init([](int32_t total_rows, const std::shared_ptr<SchemaObj> &schema, const py::list &columns_list) {
|
||||||
auto random_node =
|
auto random_node = std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), nullptr);
|
||||||
std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), nullptr);
|
|
||||||
THROW_IF_ERROR(random_node->ValidateParams());
|
THROW_IF_ERROR(random_node->ValidateParams());
|
||||||
return random_node;
|
return random_node;
|
||||||
}))
|
}))
|
||||||
.def(py::init([](int32_t total_rows, std::string schema, py::list columns_list) {
|
.def(py::init([](int32_t total_rows, const std::string &schema, const py::list &columns_list) {
|
||||||
auto random_node =
|
auto random_node = std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), nullptr);
|
||||||
std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), nullptr);
|
|
||||||
THROW_IF_ERROR(random_node->ValidateParams());
|
THROW_IF_ERROR(random_node->ValidateParams());
|
||||||
return random_node;
|
return random_node;
|
||||||
}));
|
}));
|
||||||
|
@ -505,7 +518,7 @@ PYBIND_REGISTER(RandomNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(SBUNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(SBUNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<SBUNode, DatasetNode, std::shared_ptr<SBUNode>>(*m, "SBUNode",
|
(void)py::class_<SBUNode, DatasetNode, std::shared_ptr<SBUNode>>(*m, "SBUNode",
|
||||||
"to create an SBUNode")
|
"to create an SBUNode")
|
||||||
.def(py::init([](std::string dataset_dir, bool decode, const py::handle &sampler) {
|
.def(py::init([](const std::string &dataset_dir, bool decode, const py::handle &sampler) {
|
||||||
auto sbu = std::make_shared<SBUNode>(dataset_dir, decode, toSamplerObj(sampler), nullptr);
|
auto sbu = std::make_shared<SBUNode>(dataset_dir, decode, toSamplerObj(sampler), nullptr);
|
||||||
THROW_IF_ERROR(sbu->ValidateParams());
|
THROW_IF_ERROR(sbu->ValidateParams());
|
||||||
return sbu;
|
return sbu;
|
||||||
|
@ -515,7 +528,7 @@ PYBIND_REGISTER(SBUNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(SemeionNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(SemeionNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<SemeionNode, DatasetNode, std::shared_ptr<SemeionNode>>(*m, "SemeionNode",
|
(void)py::class_<SemeionNode, DatasetNode, std::shared_ptr<SemeionNode>>(*m, "SemeionNode",
|
||||||
"to create a SemeionNode")
|
"to create a SemeionNode")
|
||||||
.def(py::init([](std::string dataset_dir, py::handle sampler) {
|
.def(py::init([](const std::string &dataset_dir, const py::handle &sampler) {
|
||||||
auto semeion = std::make_shared<SemeionNode>(dataset_dir, toSamplerObj(sampler), nullptr);
|
auto semeion = std::make_shared<SemeionNode>(dataset_dir, toSamplerObj(sampler), nullptr);
|
||||||
THROW_IF_ERROR(semeion->ValidateParams());
|
THROW_IF_ERROR(semeion->ValidateParams());
|
||||||
return semeion;
|
return semeion;
|
||||||
|
@ -525,8 +538,8 @@ PYBIND_REGISTER(SemeionNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(SogouNewsNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(SogouNewsNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<SogouNewsNode, DatasetNode, std::shared_ptr<SogouNewsNode>>(
|
(void)py::class_<SogouNewsNode, DatasetNode, std::shared_ptr<SogouNewsNode>>(
|
||||||
*m, "SogouNewsNode", "to create a SogouNewsNode")
|
*m, "SogouNewsNode", "to create a SogouNewsNode")
|
||||||
.def(py::init([](std::string dataset_dir, std::string usage, int64_t num_samples, int32_t shuffle,
|
.def(py::init([](const std::string &dataset_dir, const std::string &usage, int64_t num_samples,
|
||||||
int32_t num_shards, int32_t shard_id) {
|
int32_t shuffle, int32_t num_shards, int32_t shard_id) {
|
||||||
auto sogou_news = std::make_shared<SogouNewsNode>(
|
auto sogou_news = std::make_shared<SogouNewsNode>(
|
||||||
dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
|
dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
|
||||||
THROW_IF_ERROR(sogou_news->ValidateParams());
|
THROW_IF_ERROR(sogou_news->ValidateParams());
|
||||||
|
@ -537,7 +550,8 @@ PYBIND_REGISTER(SogouNewsNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(SpeechCommandsNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(SpeechCommandsNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<SpeechCommandsNode, DatasetNode, std::shared_ptr<SpeechCommandsNode>>(
|
(void)py::class_<SpeechCommandsNode, DatasetNode, std::shared_ptr<SpeechCommandsNode>>(
|
||||||
*m, "SpeechCommandsNode", "to create a SpeechCommandsNode")
|
*m, "SpeechCommandsNode", "to create a SpeechCommandsNode")
|
||||||
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) {
|
.def(
|
||||||
|
py::init([](const std::string &dataset_dir, const std::string &usage, const py::handle &sampler) {
|
||||||
auto speech_commands =
|
auto speech_commands =
|
||||||
std::make_shared<SpeechCommandsNode>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
|
std::make_shared<SpeechCommandsNode>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
|
||||||
THROW_IF_ERROR(speech_commands->ValidateParams());
|
THROW_IF_ERROR(speech_commands->ValidateParams());
|
||||||
|
@ -548,7 +562,8 @@ PYBIND_REGISTER(SpeechCommandsNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(STL10Node, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(STL10Node, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<STL10Node, DatasetNode, std::shared_ptr<STL10Node>>(*m, "STL10Node",
|
(void)py::class_<STL10Node, DatasetNode, std::shared_ptr<STL10Node>>(*m, "STL10Node",
|
||||||
"to create a STL10Node")
|
"to create a STL10Node")
|
||||||
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) {
|
.def(
|
||||||
|
py::init([](const std::string &dataset_dir, const std::string &usage, const py::handle &sampler) {
|
||||||
auto stl10 = std::make_shared<STL10Node>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
|
auto stl10 = std::make_shared<STL10Node>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
|
||||||
THROW_IF_ERROR(stl10->ValidateParams());
|
THROW_IF_ERROR(stl10->ValidateParams());
|
||||||
return stl10;
|
return stl10;
|
||||||
|
@ -558,8 +573,9 @@ PYBIND_REGISTER(STL10Node, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(TedliumNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(TedliumNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<TedliumNode, DatasetNode, std::shared_ptr<TedliumNode>>(*m, "TedliumNode",
|
(void)py::class_<TedliumNode, DatasetNode, std::shared_ptr<TedliumNode>>(*m, "TedliumNode",
|
||||||
"to create a TedliumNode")
|
"to create a TedliumNode")
|
||||||
.def(py::init([](std::string dataset_dir, std::string release, std::string usage,
|
.def(
|
||||||
std::string extensions, py::handle sampler) {
|
py::init([](const std::string &dataset_dir, const std::string &release, const std::string &usage,
|
||||||
|
const std::string &extensions, const py::handle &sampler) {
|
||||||
auto tedlium = std::make_shared<TedliumNode>(dataset_dir, release, usage, extensions,
|
auto tedlium = std::make_shared<TedliumNode>(dataset_dir, release, usage, extensions,
|
||||||
toSamplerObj(sampler), nullptr);
|
toSamplerObj(sampler), nullptr);
|
||||||
THROW_IF_ERROR(tedlium->ValidateParams());
|
THROW_IF_ERROR(tedlium->ValidateParams());
|
||||||
|
@ -570,8 +586,8 @@ PYBIND_REGISTER(TedliumNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(TextFileNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(TextFileNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<TextFileNode, DatasetNode, std::shared_ptr<TextFileNode>>(*m, "TextFileNode",
|
(void)py::class_<TextFileNode, DatasetNode, std::shared_ptr<TextFileNode>>(*m, "TextFileNode",
|
||||||
"to create a TextFileNode")
|
"to create a TextFileNode")
|
||||||
.def(py::init([](py::list dataset_files, int32_t num_samples, int32_t shuffle, int32_t num_shards,
|
.def(py::init([](const py::list &dataset_files, int32_t num_samples, int32_t shuffle,
|
||||||
int32_t shard_id) {
|
int32_t num_shards, int32_t shard_id) {
|
||||||
std::shared_ptr<TextFileNode> textfile_node =
|
std::shared_ptr<TextFileNode> textfile_node =
|
||||||
std::make_shared<TextFileNode>(toStringVector(dataset_files), num_samples,
|
std::make_shared<TextFileNode>(toStringVector(dataset_files), num_samples,
|
||||||
toShuffleMode(shuffle), num_shards, shard_id, nullptr);
|
toShuffleMode(shuffle), num_shards, shard_id, nullptr);
|
||||||
|
@ -583,8 +599,8 @@ PYBIND_REGISTER(TextFileNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(TFRecordNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(TFRecordNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<TFRecordNode, DatasetNode, std::shared_ptr<TFRecordNode>>(*m, "TFRecordNode",
|
(void)py::class_<TFRecordNode, DatasetNode, std::shared_ptr<TFRecordNode>>(*m, "TFRecordNode",
|
||||||
"to create a TFRecordNode")
|
"to create a TFRecordNode")
|
||||||
.def(py::init([](const py::list dataset_files, std::shared_ptr<SchemaObj> schema,
|
.def(py::init([](const py::list &dataset_files, const std::shared_ptr<SchemaObj> &schema,
|
||||||
const py::list columns_list, int64_t num_samples, int32_t shuffle,
|
const py::list &columns_list, int64_t num_samples, int32_t shuffle,
|
||||||
int32_t num_shards, int32_t shard_id, bool shard_equal_rows) {
|
int32_t num_shards, int32_t shard_id, bool shard_equal_rows) {
|
||||||
std::shared_ptr<TFRecordNode> tfrecord = std::make_shared<TFRecordNode>(
|
std::shared_ptr<TFRecordNode> tfrecord = std::make_shared<TFRecordNode>(
|
||||||
toStringVector(dataset_files), schema, toStringVector(columns_list), num_samples,
|
toStringVector(dataset_files), schema, toStringVector(columns_list), num_samples,
|
||||||
|
@ -592,9 +608,9 @@ PYBIND_REGISTER(TFRecordNode, 2, ([](const py::module *m) {
|
||||||
THROW_IF_ERROR(tfrecord->ValidateParams());
|
THROW_IF_ERROR(tfrecord->ValidateParams());
|
||||||
return tfrecord;
|
return tfrecord;
|
||||||
}))
|
}))
|
||||||
.def(py::init([](const py::list dataset_files, std::string schema, const py::list columns_list,
|
.def(py::init([](const py::list &dataset_files, const std::string &schema,
|
||||||
int64_t num_samples, int32_t shuffle, int32_t num_shards, int32_t shard_id,
|
const py::list &columns_list, int64_t num_samples, int32_t shuffle,
|
||||||
bool shard_equal_rows) {
|
int32_t num_shards, int32_t shard_id, bool shard_equal_rows) {
|
||||||
std::shared_ptr<TFRecordNode> tfrecord = std::make_shared<TFRecordNode>(
|
std::shared_ptr<TFRecordNode> tfrecord = std::make_shared<TFRecordNode>(
|
||||||
toStringVector(dataset_files), schema, toStringVector(columns_list), num_samples,
|
toStringVector(dataset_files), schema, toStringVector(columns_list), num_samples,
|
||||||
toShuffleMode(shuffle), num_shards, shard_id, shard_equal_rows, nullptr);
|
toShuffleMode(shuffle), num_shards, shard_id, shard_equal_rows, nullptr);
|
||||||
|
@ -606,8 +622,8 @@ PYBIND_REGISTER(TFRecordNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(UDPOSNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(UDPOSNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<UDPOSNode, DatasetNode, std::shared_ptr<UDPOSNode>>(*m, "UDPOSNode",
|
(void)py::class_<UDPOSNode, DatasetNode, std::shared_ptr<UDPOSNode>>(*m, "UDPOSNode",
|
||||||
"to create an UDPOSNode")
|
"to create an UDPOSNode")
|
||||||
.def(py::init([](std::string dataset_dir, std::string usage, int64_t num_samples, int32_t shuffle,
|
.def(py::init([](const std::string &dataset_dir, const std::string &usage, int64_t num_samples,
|
||||||
int32_t num_shards, int32_t shard_id) {
|
int32_t shuffle, int32_t num_shards, int32_t shard_id) {
|
||||||
std::shared_ptr<UDPOSNode> udpos = std::make_shared<UDPOSNode>(
|
std::shared_ptr<UDPOSNode> udpos = std::make_shared<UDPOSNode>(
|
||||||
dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
|
dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
|
||||||
THROW_IF_ERROR(udpos->ValidateParams());
|
THROW_IF_ERROR(udpos->ValidateParams());
|
||||||
|
@ -618,8 +634,8 @@ PYBIND_REGISTER(UDPOSNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(USPSNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(USPSNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<USPSNode, DatasetNode, std::shared_ptr<USPSNode>>(*m, "USPSNode",
|
(void)py::class_<USPSNode, DatasetNode, std::shared_ptr<USPSNode>>(*m, "USPSNode",
|
||||||
"to create an USPSNode")
|
"to create an USPSNode")
|
||||||
.def(py::init([](std::string dataset_dir, std::string usage, int32_t num_samples, int32_t shuffle,
|
.def(py::init([](const std::string &dataset_dir, const std::string &usage, int32_t num_samples,
|
||||||
int32_t num_shards, int32_t shard_id) {
|
int32_t shuffle, int32_t num_shards, int32_t shard_id) {
|
||||||
auto usps = std::make_shared<USPSNode>(dataset_dir, usage, num_samples, toShuffleMode(shuffle),
|
auto usps = std::make_shared<USPSNode>(dataset_dir, usage, num_samples, toShuffleMode(shuffle),
|
||||||
num_shards, shard_id, nullptr);
|
num_shards, shard_id, nullptr);
|
||||||
THROW_IF_ERROR(usps->ValidateParams());
|
THROW_IF_ERROR(usps->ValidateParams());
|
||||||
|
@ -629,7 +645,7 @@ PYBIND_REGISTER(USPSNode, 2, ([](const py::module *m) {
|
||||||
|
|
||||||
PYBIND_REGISTER(VOCNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(VOCNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<VOCNode, DatasetNode, std::shared_ptr<VOCNode>>(*m, "VOCNode", "to create a VOCNode")
|
(void)py::class_<VOCNode, DatasetNode, std::shared_ptr<VOCNode>>(*m, "VOCNode", "to create a VOCNode")
|
||||||
.def(py::init([](std::string dataset_dir, std::string task, std::string usage,
|
.def(py::init([](const std::string &dataset_dir, const std::string &task, const std::string &usage,
|
||||||
const py::dict &class_indexing, bool decode, const py::handle &sampler,
|
const py::dict &class_indexing, bool decode, const py::handle &sampler,
|
||||||
bool extra_metadata) {
|
bool extra_metadata) {
|
||||||
std::shared_ptr<VOCNode> voc =
|
std::shared_ptr<VOCNode> voc =
|
||||||
|
@ -643,7 +659,8 @@ PYBIND_REGISTER(VOCNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(WIDERFaceNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(WIDERFaceNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<WIDERFaceNode, DatasetNode, std::shared_ptr<WIDERFaceNode>>(
|
(void)py::class_<WIDERFaceNode, DatasetNode, std::shared_ptr<WIDERFaceNode>>(
|
||||||
*m, "WIDERFaceNode", "to create a WIDERFaceNode")
|
*m, "WIDERFaceNode", "to create a WIDERFaceNode")
|
||||||
.def(py::init([](std::string dataset_dir, std::string usage, bool decode, py::handle sampler) {
|
.def(py::init([](const std::string &dataset_dir, const std::string &usage, bool decode,
|
||||||
|
const py::handle &sampler) {
|
||||||
auto wider_face =
|
auto wider_face =
|
||||||
std::make_shared<WIDERFaceNode>(dataset_dir, usage, decode, toSamplerObj(sampler), nullptr);
|
std::make_shared<WIDERFaceNode>(dataset_dir, usage, decode, toSamplerObj(sampler), nullptr);
|
||||||
THROW_IF_ERROR(wider_face->ValidateParams());
|
THROW_IF_ERROR(wider_face->ValidateParams());
|
||||||
|
@ -654,8 +671,8 @@ PYBIND_REGISTER(WIDERFaceNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(WikiTextNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(WikiTextNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<WikiTextNode, DatasetNode, std::shared_ptr<WikiTextNode>>(*m, "WikiTextNode",
|
(void)py::class_<WikiTextNode, DatasetNode, std::shared_ptr<WikiTextNode>>(*m, "WikiTextNode",
|
||||||
"to create a WikiTextNode")
|
"to create a WikiTextNode")
|
||||||
.def(py::init([](std::string dataset_dir, std::string usage, int32_t num_samples, int32_t shuffle,
|
.def(py::init([](const std::string &dataset_dir, const std::string &usage, int32_t num_samples,
|
||||||
int32_t num_shards, int32_t shard_id) {
|
int32_t shuffle, int32_t num_shards, int32_t shard_id) {
|
||||||
auto wiki_text = std::make_shared<WikiTextNode>(
|
auto wiki_text = std::make_shared<WikiTextNode>(
|
||||||
dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
|
dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
|
||||||
THROW_IF_ERROR(wiki_text->ValidateParams());
|
THROW_IF_ERROR(wiki_text->ValidateParams());
|
||||||
|
@ -666,8 +683,8 @@ PYBIND_REGISTER(WikiTextNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(YahooAnswersNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(YahooAnswersNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<YahooAnswersNode, DatasetNode, std::shared_ptr<YahooAnswersNode>>(
|
(void)py::class_<YahooAnswersNode, DatasetNode, std::shared_ptr<YahooAnswersNode>>(
|
||||||
*m, "YahooAnswersNode", "to create a YahooAnswersNode")
|
*m, "YahooAnswersNode", "to create a YahooAnswersNode")
|
||||||
.def(py::init([](std::string dataset_dir, std::string usage, int64_t num_samples, int32_t shuffle,
|
.def(py::init([](const std::string &dataset_dir, const std::string &usage, int64_t num_samples,
|
||||||
int32_t num_shards, int32_t shard_id) {
|
int32_t shuffle, int32_t num_shards, int32_t shard_id) {
|
||||||
auto yahoo_answers = std::make_shared<YahooAnswersNode>(
|
auto yahoo_answers = std::make_shared<YahooAnswersNode>(
|
||||||
dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
|
dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
|
||||||
THROW_IF_ERROR(yahoo_answers->ValidateParams());
|
THROW_IF_ERROR(yahoo_answers->ValidateParams());
|
||||||
|
@ -678,8 +695,8 @@ PYBIND_REGISTER(YahooAnswersNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(YelpReviewNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(YelpReviewNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<YelpReviewNode, DatasetNode, std::shared_ptr<YelpReviewNode>>(
|
(void)py::class_<YelpReviewNode, DatasetNode, std::shared_ptr<YelpReviewNode>>(
|
||||||
*m, "YelpReviewNode", "to create a YelpReviewNode")
|
*m, "YelpReviewNode", "to create a YelpReviewNode")
|
||||||
.def(py::init([](std::string dataset_dir, std::string usage, int64_t num_samples, int32_t shuffle,
|
.def(py::init([](const std::string &dataset_dir, const std::string &usage, int64_t num_samples,
|
||||||
int32_t num_shards, int32_t shard_id) {
|
int32_t shuffle, int32_t num_shards, int32_t shard_id) {
|
||||||
std::shared_ptr<YelpReviewNode> yelp_review = std::make_shared<YelpReviewNode>(
|
std::shared_ptr<YelpReviewNode> yelp_review = std::make_shared<YelpReviewNode>(
|
||||||
dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
|
dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
|
||||||
THROW_IF_ERROR(yelp_review->ValidateParams());
|
THROW_IF_ERROR(yelp_review->ValidateParams());
|
||||||
|
@ -690,7 +707,7 @@ PYBIND_REGISTER(YelpReviewNode, 2, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(YesNoNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(YesNoNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<YesNoNode, DatasetNode, std::shared_ptr<YesNoNode>>(*m, "YesNoNode",
|
(void)py::class_<YesNoNode, DatasetNode, std::shared_ptr<YesNoNode>>(*m, "YesNoNode",
|
||||||
"to create a YesNoNode")
|
"to create a YesNoNode")
|
||||||
.def(py::init([](std::string dataset_dir, py::handle sampler) {
|
.def(py::init([](const std::string &dataset_dir, const py::handle &sampler) {
|
||||||
auto yes_no = std::make_shared<YesNoNode>(dataset_dir, toSamplerObj(sampler), nullptr);
|
auto yes_no = std::make_shared<YesNoNode>(dataset_dir, toSamplerObj(sampler), nullptr);
|
||||||
THROW_IF_ERROR(yes_no->ValidateParams());
|
THROW_IF_ERROR(yes_no->ValidateParams());
|
||||||
return yes_no;
|
return yes_no;
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -14,6 +14,7 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
#include "pybind11/pybind11.h"
|
#include "pybind11/pybind11.h"
|
||||||
|
|
||||||
#include "minddata/dataset/api/python/pybind_conversion.h"
|
#include "minddata/dataset/api/python/pybind_conversion.h"
|
||||||
#include "minddata/dataset/api/python/pybind_register.h"
|
#include "minddata/dataset/api/python/pybind_register.h"
|
||||||
#include "minddata/dataset/include/dataset/transforms.h"
|
#include "minddata/dataset/include/dataset/transforms.h"
|
||||||
|
@ -112,7 +113,7 @@ PYBIND_REGISTER(BoundingBoxAugmentOperation, 1, ([](const py::module *m) {
|
||||||
(void)py::class_<vision::BoundingBoxAugmentOperation, TensorOperation,
|
(void)py::class_<vision::BoundingBoxAugmentOperation, TensorOperation,
|
||||||
std::shared_ptr<vision::BoundingBoxAugmentOperation>>(*m,
|
std::shared_ptr<vision::BoundingBoxAugmentOperation>>(*m,
|
||||||
"BoundingBoxAugmentOperation")
|
"BoundingBoxAugmentOperation")
|
||||||
.def(py::init([](const py::object transform, float ratio) {
|
.def(py::init([](const py::object &transform, float ratio) {
|
||||||
auto bounding_box_augment = std::make_shared<vision::BoundingBoxAugmentOperation>(
|
auto bounding_box_augment = std::make_shared<vision::BoundingBoxAugmentOperation>(
|
||||||
std::move(toTensorOperation(transform)), ratio);
|
std::move(toTensorOperation(transform)), ratio);
|
||||||
THROW_IF_ERROR(bounding_box_augment->ValidateParams());
|
THROW_IF_ERROR(bounding_box_augment->ValidateParams());
|
||||||
|
@ -207,7 +208,7 @@ PYBIND_REGISTER(
|
||||||
GaussianBlurOperation, 1, ([](const py::module *m) {
|
GaussianBlurOperation, 1, ([](const py::module *m) {
|
||||||
(void)py::class_<vision::GaussianBlurOperation, TensorOperation, std::shared_ptr<vision::GaussianBlurOperation>>(
|
(void)py::class_<vision::GaussianBlurOperation, TensorOperation, std::shared_ptr<vision::GaussianBlurOperation>>(
|
||||||
*m, "GaussianBlurOperation")
|
*m, "GaussianBlurOperation")
|
||||||
.def(py::init([](std::vector<int32_t> kernel_size, std::vector<float> sigma) {
|
.def(py::init([](const std::vector<int32_t> &kernel_size, const std::vector<float> &sigma) {
|
||||||
auto gaussian_blur = std::make_shared<vision::GaussianBlurOperation>(kernel_size, sigma);
|
auto gaussian_blur = std::make_shared<vision::GaussianBlurOperation>(kernel_size, sigma);
|
||||||
THROW_IF_ERROR(gaussian_blur->ValidateParams());
|
THROW_IF_ERROR(gaussian_blur->ValidateParams());
|
||||||
return gaussian_blur;
|
return gaussian_blur;
|
||||||
|
@ -504,13 +505,15 @@ PYBIND_REGISTER(RandomResizeWithBBoxOperation, 1, ([](const py::module *m) {
|
||||||
}));
|
}));
|
||||||
}));
|
}));
|
||||||
|
|
||||||
PYBIND_REGISTER(RandomRotationOperation, 1, ([](const py::module *m) {
|
PYBIND_REGISTER(
|
||||||
(void)py::class_<vision::RandomRotationOperation, TensorOperation,
|
RandomRotationOperation, 1, ([](const py::module *m) {
|
||||||
std::shared_ptr<vision::RandomRotationOperation>>(*m, "RandomRotationOperation")
|
(void)
|
||||||
.def(py::init([](std::vector<float> degrees, InterpolationMode interpolation_mode, bool expand,
|
py::class_<vision::RandomRotationOperation, TensorOperation, std::shared_ptr<vision::RandomRotationOperation>>(
|
||||||
std::vector<float> center, std::vector<uint8_t> fill_value) {
|
*m, "RandomRotationOperation")
|
||||||
auto random_rotation = std::make_shared<vision::RandomRotationOperation>(
|
.def(py::init([](const std::vector<float> °rees, InterpolationMode interpolation_mode, bool expand,
|
||||||
degrees, interpolation_mode, expand, center, fill_value);
|
const std::vector<float> ¢er, const std::vector<uint8_t> &fill_value) {
|
||||||
|
auto random_rotation =
|
||||||
|
std::make_shared<vision::RandomRotationOperation>(degrees, interpolation_mode, expand, center, fill_value);
|
||||||
THROW_IF_ERROR(random_rotation->ValidateParams());
|
THROW_IF_ERROR(random_rotation->ValidateParams());
|
||||||
return random_rotation;
|
return random_rotation;
|
||||||
}));
|
}));
|
||||||
|
@ -520,7 +523,7 @@ PYBIND_REGISTER(
|
||||||
RandomSelectSubpolicyOperation, 1, ([](const py::module *m) {
|
RandomSelectSubpolicyOperation, 1, ([](const py::module *m) {
|
||||||
(void)py::class_<vision::RandomSelectSubpolicyOperation, TensorOperation,
|
(void)py::class_<vision::RandomSelectSubpolicyOperation, TensorOperation,
|
||||||
std::shared_ptr<vision::RandomSelectSubpolicyOperation>>(*m, "RandomSelectSubpolicyOperation")
|
std::shared_ptr<vision::RandomSelectSubpolicyOperation>>(*m, "RandomSelectSubpolicyOperation")
|
||||||
.def(py::init([](const py::list py_policy) {
|
.def(py::init([](const py::list &py_policy) {
|
||||||
std::vector<std::vector<std::pair<std::shared_ptr<TensorOperation>, double>>> cpp_policy;
|
std::vector<std::vector<std::pair<std::shared_ptr<TensorOperation>, double>>> cpp_policy;
|
||||||
for (auto &py_sub : py_policy) {
|
for (auto &py_sub : py_policy) {
|
||||||
cpp_policy.push_back({});
|
cpp_policy.push_back({});
|
||||||
|
@ -643,8 +646,8 @@ PYBIND_REGISTER(RgbToBgrOperation, 1, ([](const py::module *m) {
|
||||||
PYBIND_REGISTER(RotateOperation, 1, ([](const py::module *m) {
|
PYBIND_REGISTER(RotateOperation, 1, ([](const py::module *m) {
|
||||||
(void)py::class_<vision::RotateOperation, TensorOperation, std::shared_ptr<vision::RotateOperation>>(
|
(void)py::class_<vision::RotateOperation, TensorOperation, std::shared_ptr<vision::RotateOperation>>(
|
||||||
*m, "RotateOperation")
|
*m, "RotateOperation")
|
||||||
.def(py::init([](float degrees, InterpolationMode resample, bool expand, std::vector<float> center,
|
.def(py::init([](float degrees, InterpolationMode resample, bool expand,
|
||||||
std::vector<uint8_t> fill_value) {
|
const std::vector<float> ¢er, const std::vector<uint8_t> &fill_value) {
|
||||||
auto rotate =
|
auto rotate =
|
||||||
std::make_shared<vision::RotateOperation>(degrees, resample, expand, center, fill_value);
|
std::make_shared<vision::RotateOperation>(degrees, resample, expand, center, fill_value);
|
||||||
THROW_IF_ERROR(rotate->ValidateParams());
|
THROW_IF_ERROR(rotate->ValidateParams());
|
||||||
|
@ -694,7 +697,7 @@ PYBIND_REGISTER(
|
||||||
UniformAugOperation, 1, ([](const py::module *m) {
|
UniformAugOperation, 1, ([](const py::module *m) {
|
||||||
(void)py::class_<vision::UniformAugOperation, TensorOperation, std::shared_ptr<vision::UniformAugOperation>>(
|
(void)py::class_<vision::UniformAugOperation, TensorOperation, std::shared_ptr<vision::UniformAugOperation>>(
|
||||||
*m, "UniformAugOperation")
|
*m, "UniformAugOperation")
|
||||||
.def(py::init([](const py::list transforms, int32_t num_ops) {
|
.def(py::init([](const py::list &transforms, int32_t num_ops) {
|
||||||
auto uniform_aug =
|
auto uniform_aug =
|
||||||
std::make_shared<vision::UniformAugOperation>(std::move(toTensorOperations(transforms)), num_ops);
|
std::make_shared<vision::UniformAugOperation>(std::move(toTensorOperations(transforms)), num_ops);
|
||||||
THROW_IF_ERROR(uniform_aug->ValidateParams());
|
THROW_IF_ERROR(uniform_aug->ValidateParams());
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -15,9 +15,11 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "minddata/dataset/include/dataset/samplers.h"
|
#include "minddata/dataset/include/dataset/samplers.h"
|
||||||
|
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/samplers/distributed_sampler_ir.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/samplers/distributed_sampler_ir.h"
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/samplers/pk_sampler_ir.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/samplers/pk_sampler_ir.h"
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/samplers/prebuilt_sampler_ir.h"
|
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/samplers/random_sampler_ir.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/samplers/random_sampler_ir.h"
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/samplers/sequential_sampler_ir.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/samplers/sequential_sampler_ir.h"
|
||||||
|
@ -27,7 +29,6 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
|
||||||
// DistributedSampler
|
// DistributedSampler
|
||||||
DistributedSampler::DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples,
|
DistributedSampler::DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples,
|
||||||
uint32_t seed, int64_t offset, bool even_dist)
|
uint32_t seed, int64_t offset, bool even_dist)
|
||||||
|
@ -69,7 +70,7 @@ std::shared_ptr<SamplerObj> SequentialSampler::Parse() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
// SubsetSampler
|
// SubsetSampler
|
||||||
SubsetSampler::SubsetSampler(std::vector<int64_t> indices, int64_t num_samples)
|
SubsetSampler::SubsetSampler(const std::vector<int64_t> &indices, int64_t num_samples)
|
||||||
: indices_(indices), num_samples_(num_samples) {}
|
: indices_(indices), num_samples_(num_samples) {}
|
||||||
|
|
||||||
std::shared_ptr<SamplerObj> SubsetSampler::Parse() const {
|
std::shared_ptr<SamplerObj> SubsetSampler::Parse() const {
|
||||||
|
@ -77,7 +78,7 @@ std::shared_ptr<SamplerObj> SubsetSampler::Parse() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
// SubsetRandomSampler
|
// SubsetRandomSampler
|
||||||
SubsetRandomSampler::SubsetRandomSampler(std::vector<int64_t> indices, int64_t num_samples)
|
SubsetRandomSampler::SubsetRandomSampler(const std::vector<int64_t> &indices, int64_t num_samples)
|
||||||
: SubsetSampler(indices, num_samples) {}
|
: SubsetSampler(indices, num_samples) {}
|
||||||
|
|
||||||
std::shared_ptr<SamplerObj> SubsetRandomSampler::Parse() const {
|
std::shared_ptr<SamplerObj> SubsetRandomSampler::Parse() const {
|
||||||
|
@ -85,12 +86,11 @@ std::shared_ptr<SamplerObj> SubsetRandomSampler::Parse() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
// WeightedRandomSampler
|
// WeightedRandomSampler
|
||||||
WeightedRandomSampler::WeightedRandomSampler(std::vector<double> weights, int64_t num_samples, bool replacement)
|
WeightedRandomSampler::WeightedRandomSampler(const std::vector<double> &weights, int64_t num_samples, bool replacement)
|
||||||
: weights_(weights), num_samples_(num_samples), replacement_(replacement) {}
|
: weights_(weights), num_samples_(num_samples), replacement_(replacement) {}
|
||||||
|
|
||||||
std::shared_ptr<SamplerObj> WeightedRandomSampler::Parse() const {
|
std::shared_ptr<SamplerObj> WeightedRandomSampler::Parse() const {
|
||||||
return std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_);
|
return std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -13,23 +13,21 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
#include "minddata/dataset/include/dataset/text.h"
|
||||||
|
|
||||||
#include <unistd.h>
|
#include <unistd.h>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <regex>
|
#include <regex>
|
||||||
|
|
||||||
#include "utils/file_utils.h"
|
|
||||||
#include "minddata/dataset/include/dataset/text.h"
|
|
||||||
#include "minddata/dataset/core/type_id.h"
|
#include "minddata/dataset/core/type_id.h"
|
||||||
#include "minddata/dataset/text/ir/kernels/text_ir.h"
|
#include "minddata/dataset/text/ir/kernels/text_ir.h"
|
||||||
#include "mindspore/core/ir/dtype/type_id.h"
|
#include "mindspore/core/ir/dtype/type_id.h"
|
||||||
|
#include "utils/file_utils.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
|
||||||
// Transform operations for text.
|
// Transform operations for text.
|
||||||
namespace text {
|
namespace text {
|
||||||
|
|
||||||
constexpr size_t size_two = 2;
|
constexpr size_t size_two = 2;
|
||||||
constexpr size_t size_three = 3;
|
constexpr size_t size_three = 3;
|
||||||
constexpr int64_t value_one = 1;
|
constexpr int64_t value_one = 1;
|
||||||
|
@ -104,7 +102,7 @@ std::shared_ptr<TensorOperation> BertTokenizer::Parse() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CaseFold
|
// CaseFold
|
||||||
CaseFold::CaseFold() {}
|
CaseFold::CaseFold() = default;
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> CaseFold::Parse() { return std::make_shared<CaseFoldOperation>(); }
|
std::shared_ptr<TensorOperation> CaseFold::Parse() { return std::make_shared<CaseFoldOperation>(); }
|
||||||
#endif
|
#endif
|
||||||
|
@ -170,6 +168,7 @@ Status JiebaTokenizer::AddDictChar(const std::vector<char> &file_path) {
|
||||||
|
|
||||||
Status JiebaTokenizer::ParserFile(const std::string &file_path,
|
Status JiebaTokenizer::ParserFile(const std::string &file_path,
|
||||||
std::vector<std::pair<std::string, int64_t>> *const user_dict) {
|
std::vector<std::pair<std::string, int64_t>> *const user_dict) {
|
||||||
|
RETURN_UNEXPECTED_IF_NULL(user_dict);
|
||||||
auto realpath = FileUtils::GetRealPath(file_path.data());
|
auto realpath = FileUtils::GetRealPath(file_path.data());
|
||||||
if (!realpath.has_value()) {
|
if (!realpath.has_value()) {
|
||||||
std::string err_msg = "Get real path failed, path: " + file_path;
|
std::string err_msg = "Get real path failed, path: " + file_path;
|
||||||
|
@ -193,7 +192,7 @@ Status JiebaTokenizer::ParserFile(const std::string &file_path,
|
||||||
if (tokens.size() == size_two) {
|
if (tokens.size() == size_two) {
|
||||||
(void)user_dict->emplace_back(tokens.str(value_one), 0);
|
(void)user_dict->emplace_back(tokens.str(value_one), 0);
|
||||||
} else if (tokens.size() == size_three) {
|
} else if (tokens.size() == size_three) {
|
||||||
(void)user_dict->emplace_back(tokens.str(value_one), strtoll(tokens.str(value_two).c_str(), NULL, 0));
|
(void)user_dict->emplace_back(tokens.str(value_one), strtoll(tokens.str(value_two).c_str(), nullptr, 0));
|
||||||
} else {
|
} else {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -201,6 +200,7 @@ Status JiebaTokenizer::ParserFile(const std::string &file_path,
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
ifs.close();
|
||||||
MS_LOG(INFO) << "JiebaTokenizer::AddDict: The size of user input dictionary is: " << user_dict->size();
|
MS_LOG(INFO) << "JiebaTokenizer::AddDict: The size of user input dictionary is: " << user_dict->size();
|
||||||
MS_LOG(INFO) << "Valid rows in input dictionary (Maximum of first 10 rows are shown.):";
|
MS_LOG(INFO) << "Valid rows in input dictionary (Maximum of first 10 rows are shown.):";
|
||||||
for (std::size_t i = 0; i != user_dict->size(); ++i) {
|
for (std::size_t i = 0; i != user_dict->size(); ++i) {
|
||||||
|
@ -367,7 +367,8 @@ struct ToVectors::Data {
|
||||||
bool lower_case_backup_;
|
bool lower_case_backup_;
|
||||||
};
|
};
|
||||||
|
|
||||||
ToVectors::ToVectors(const std::shared_ptr<Vectors> &vectors, const std::vector<float> unk_init, bool lower_case_backup)
|
ToVectors::ToVectors(const std::shared_ptr<Vectors> &vectors, const std::vector<float> &unk_init,
|
||||||
|
bool lower_case_backup)
|
||||||
: data_(std::make_shared<Data>(vectors, unk_init, lower_case_backup)) {}
|
: data_(std::make_shared<Data>(vectors, unk_init, lower_case_backup)) {}
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> ToVectors::Parse() {
|
std::shared_ptr<TensorOperation> ToVectors::Parse() {
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -18,17 +18,14 @@
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
#include "mindspore/ccsrc/minddata/dataset/core/type_id.h"
|
|
||||||
#include "mindspore/core/ir/dtype/type_id.h"
|
|
||||||
#include "minddata/dataset/core/type_id.h"
|
#include "minddata/dataset/core/type_id.h"
|
||||||
#include "minddata/dataset/kernels/ir/data/transforms_ir.h"
|
#include "minddata/dataset/kernels/ir/data/transforms_ir.h"
|
||||||
|
#include "mindspore/core/ir/dtype/type_id.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
|
||||||
// Transform operations for data.
|
// Transform operations for data.
|
||||||
namespace transforms {
|
namespace transforms {
|
||||||
|
|
||||||
// API CLASS FOR DATA TRANSFORM OPERATIONS
|
// API CLASS FOR DATA TRANSFORM OPERATIONS
|
||||||
// (In alphabetical order)
|
// (In alphabetical order)
|
||||||
|
|
||||||
|
@ -46,7 +43,7 @@ Compose::Compose(const std::vector<TensorTransform *> &transforms) : data_(std::
|
||||||
|
|
||||||
Compose::Compose(const std::vector<std::shared_ptr<TensorTransform>> &transforms) : data_(std::make_shared<Data>()) {
|
Compose::Compose(const std::vector<std::shared_ptr<TensorTransform>> &transforms) : data_(std::make_shared<Data>()) {
|
||||||
(void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
|
(void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
|
||||||
[](std::shared_ptr<TensorTransform> op) -> std::shared_ptr<TensorOperation> {
|
[](const std::shared_ptr<TensorTransform> &op) -> std::shared_ptr<TensorOperation> {
|
||||||
return op != nullptr ? op->Parse() : nullptr;
|
return op != nullptr ? op->Parse() : nullptr;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -92,7 +89,7 @@ std::shared_ptr<TensorOperation> Concatenate::Parse() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Constructor to Duplicate
|
// Constructor to Duplicate
|
||||||
Duplicate::Duplicate() {}
|
Duplicate::Duplicate() = default;
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> Duplicate::Parse() { return std::make_shared<DuplicateOperation>(); }
|
std::shared_ptr<TensorOperation> Duplicate::Parse() { return std::make_shared<DuplicateOperation>(); }
|
||||||
|
|
||||||
|
@ -202,7 +199,7 @@ RandomApply::RandomApply(const std::vector<TensorTransform *> &transforms, doubl
|
||||||
RandomApply::RandomApply(const std::vector<std::shared_ptr<TensorTransform>> &transforms, double prob)
|
RandomApply::RandomApply(const std::vector<std::shared_ptr<TensorTransform>> &transforms, double prob)
|
||||||
: data_(std::make_shared<Data>()) {
|
: data_(std::make_shared<Data>()) {
|
||||||
(void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
|
(void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
|
||||||
[](std::shared_ptr<TensorTransform> op) -> std::shared_ptr<TensorOperation> {
|
[](const std::shared_ptr<TensorTransform> &op) -> std::shared_ptr<TensorOperation> {
|
||||||
return op != nullptr ? op->Parse() : nullptr;
|
return op != nullptr ? op->Parse() : nullptr;
|
||||||
});
|
});
|
||||||
data_->prob_ = prob;
|
data_->prob_ = prob;
|
||||||
|
@ -234,7 +231,7 @@ RandomChoice::RandomChoice(const std::vector<TensorTransform *> &transforms) : d
|
||||||
RandomChoice::RandomChoice(const std::vector<std::shared_ptr<TensorTransform>> &transforms)
|
RandomChoice::RandomChoice(const std::vector<std::shared_ptr<TensorTransform>> &transforms)
|
||||||
: data_(std::make_shared<Data>()) {
|
: data_(std::make_shared<Data>()) {
|
||||||
(void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
|
(void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
|
||||||
[](const std::shared_ptr<TensorTransform> op) -> std::shared_ptr<TensorOperation> {
|
[](const std::shared_ptr<TensorTransform> &op) -> std::shared_ptr<TensorOperation> {
|
||||||
return op != nullptr ? op->Parse() : nullptr;
|
return op != nullptr ? op->Parse() : nullptr;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -278,7 +275,7 @@ TypeCast::TypeCast(mindspore::DataType data_type) : data_(std::make_shared<Data>
|
||||||
std::shared_ptr<TensorOperation> TypeCast::Parse() { return std::make_shared<TypeCastOperation>(data_->data_type_); }
|
std::shared_ptr<TensorOperation> TypeCast::Parse() { return std::make_shared<TypeCastOperation>(data_->data_type_); }
|
||||||
|
|
||||||
// Constructor to Unique
|
// Constructor to Unique
|
||||||
Unique::Unique() {}
|
Unique::Unique() = default;
|
||||||
|
|
||||||
#ifndef ENABLE_ANDROID
|
#ifndef ENABLE_ANDROID
|
||||||
std::shared_ptr<TensorOperation> Unique::Parse() { return std::make_shared<UniqueOperation>(); }
|
std::shared_ptr<TensorOperation> Unique::Parse() { return std::make_shared<UniqueOperation>(); }
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -20,7 +20,6 @@
|
||||||
#include "minddata/dataset/kernels/ir/vision/ascend_vision_ir.h"
|
#include "minddata/dataset/kernels/ir/vision/ascend_vision_ir.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include "minddata/dataset/include/dataset/transforms.h"
|
|
||||||
#include "minddata/dataset/kernels/ir/vision/adjust_gamma_ir.h"
|
#include "minddata/dataset/kernels/ir/vision/adjust_gamma_ir.h"
|
||||||
#include "minddata/dataset/kernels/ir/vision/affine_ir.h"
|
#include "minddata/dataset/kernels/ir/vision/affine_ir.h"
|
||||||
#include "minddata/dataset/kernels/ir/vision/auto_augment_ir.h"
|
#include "minddata/dataset/kernels/ir/vision/auto_augment_ir.h"
|
||||||
|
@ -88,11 +87,8 @@
|
||||||
#endif
|
#endif
|
||||||
#include "minddata/dataset/kernels/ir/validators.h"
|
#include "minddata/dataset/kernels/ir/validators.h"
|
||||||
|
|
||||||
// Kernel image headers (in alphabetical order)
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
|
||||||
// Transform operations for computer vision.
|
// Transform operations for computer vision.
|
||||||
namespace vision {
|
namespace vision {
|
||||||
// CONSTRUCTORS FOR API CLASSES TO CREATE VISION TENSOR TRANSFORM OPERATIONS
|
// CONSTRUCTORS FOR API CLASSES TO CREATE VISION TENSOR TRANSFORM OPERATIONS
|
||||||
|
@ -163,7 +159,7 @@ struct AutoContrast::Data {
|
||||||
std::vector<uint32_t> ignore_;
|
std::vector<uint32_t> ignore_;
|
||||||
};
|
};
|
||||||
|
|
||||||
AutoContrast::AutoContrast(float cutoff, std::vector<uint32_t> ignore)
|
AutoContrast::AutoContrast(float cutoff, const std::vector<uint32_t> &ignore)
|
||||||
: data_(std::make_shared<Data>(cutoff, ignore)) {}
|
: data_(std::make_shared<Data>(cutoff, ignore)) {}
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> AutoContrast::Parse() {
|
std::shared_ptr<TensorOperation> AutoContrast::Parse() {
|
||||||
|
@ -187,7 +183,7 @@ BoundingBoxAugment::BoundingBoxAugment(const std::shared_ptr<TensorTransform> &t
|
||||||
data_->ratio_ = ratio;
|
data_->ratio_ = ratio;
|
||||||
}
|
}
|
||||||
|
|
||||||
BoundingBoxAugment::BoundingBoxAugment(const std::reference_wrapper<TensorTransform> transform, float ratio)
|
BoundingBoxAugment::BoundingBoxAugment(const std::reference_wrapper<TensorTransform> &transform, float ratio)
|
||||||
: data_(std::make_shared<Data>()) {
|
: data_(std::make_shared<Data>()) {
|
||||||
data_->transform_ = transform.get().Parse();
|
data_->transform_ = transform.get().Parse();
|
||||||
data_->ratio_ = ratio;
|
data_->ratio_ = ratio;
|
||||||
|
@ -204,7 +200,7 @@ struct CenterCrop::Data {
|
||||||
std::vector<int32_t> size_;
|
std::vector<int32_t> size_;
|
||||||
};
|
};
|
||||||
|
|
||||||
CenterCrop::CenterCrop(std::vector<int32_t> size) : data_(std::make_shared<Data>(size)) {}
|
CenterCrop::CenterCrop(const std::vector<int32_t> &size) : data_(std::make_shared<Data>(size)) {}
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> CenterCrop::Parse() { return std::make_shared<CenterCropOperation>(data_->size_); }
|
std::shared_ptr<TensorOperation> CenterCrop::Parse() { return std::make_shared<CenterCropOperation>(data_->size_); }
|
||||||
|
|
||||||
|
@ -246,7 +242,7 @@ struct Crop::Data {
|
||||||
std::vector<int32_t> size_;
|
std::vector<int32_t> size_;
|
||||||
};
|
};
|
||||||
|
|
||||||
Crop::Crop(std::vector<int32_t> coordinates, std::vector<int32_t> size)
|
Crop::Crop(const std::vector<int32_t> &coordinates, const std::vector<int32_t> &size)
|
||||||
: data_(std::make_shared<Data>(coordinates, size)) {}
|
: data_(std::make_shared<Data>(coordinates, size)) {}
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> Crop::Parse() {
|
std::shared_ptr<TensorOperation> Crop::Parse() {
|
||||||
|
@ -313,7 +309,8 @@ struct DvppDecodeResizeJpeg::Data {
|
||||||
std::vector<uint32_t> resize_;
|
std::vector<uint32_t> resize_;
|
||||||
};
|
};
|
||||||
|
|
||||||
DvppDecodeResizeJpeg::DvppDecodeResizeJpeg(std::vector<uint32_t> resize) : data_(std::make_shared<Data>(resize)) {}
|
DvppDecodeResizeJpeg::DvppDecodeResizeJpeg(const std::vector<uint32_t> &resize)
|
||||||
|
: data_(std::make_shared<Data>(resize)) {}
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> DvppDecodeResizeJpeg::Parse() {
|
std::shared_ptr<TensorOperation> DvppDecodeResizeJpeg::Parse() {
|
||||||
return std::make_shared<DvppDecodeResizeOperation>(data_->resize_);
|
return std::make_shared<DvppDecodeResizeOperation>(data_->resize_);
|
||||||
|
@ -334,7 +331,8 @@ struct DvppDecodeResizeCropJpeg::Data {
|
||||||
std::vector<uint32_t> resize_;
|
std::vector<uint32_t> resize_;
|
||||||
};
|
};
|
||||||
|
|
||||||
DvppDecodeResizeCropJpeg::DvppDecodeResizeCropJpeg(std::vector<uint32_t> crop, std::vector<uint32_t> resize)
|
DvppDecodeResizeCropJpeg::DvppDecodeResizeCropJpeg(const std::vector<uint32_t> &crop,
|
||||||
|
const std::vector<uint32_t> &resize)
|
||||||
: data_(std::make_shared<Data>(crop, resize)) {}
|
: data_(std::make_shared<Data>(crop, resize)) {}
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> DvppDecodeResizeCropJpeg::Parse() {
|
std::shared_ptr<TensorOperation> DvppDecodeResizeCropJpeg::Parse() {
|
||||||
|
@ -365,7 +363,7 @@ std::shared_ptr<TensorOperation> DvppDecodePng::Parse(const MapTargetDevice &env
|
||||||
|
|
||||||
#ifndef ENABLE_ANDROID
|
#ifndef ENABLE_ANDROID
|
||||||
// Equalize Transform Operation.
|
// Equalize Transform Operation.
|
||||||
Equalize::Equalize() {}
|
Equalize::Equalize() = default;
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> Equalize::Parse() { return std::make_shared<EqualizeOperation>(); }
|
std::shared_ptr<TensorOperation> Equalize::Parse() { return std::make_shared<EqualizeOperation>(); }
|
||||||
#endif // not ENABLE_ANDROID
|
#endif // not ENABLE_ANDROID
|
||||||
|
@ -387,17 +385,17 @@ std::shared_ptr<TensorOperation> GaussianBlur::Parse() {
|
||||||
|
|
||||||
#ifndef ENABLE_ANDROID
|
#ifndef ENABLE_ANDROID
|
||||||
// HorizontalFlip Transform Operation.
|
// HorizontalFlip Transform Operation.
|
||||||
HorizontalFlip::HorizontalFlip() {}
|
HorizontalFlip::HorizontalFlip() = default;
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> HorizontalFlip::Parse() { return std::make_shared<HorizontalFlipOperation>(); }
|
std::shared_ptr<TensorOperation> HorizontalFlip::Parse() { return std::make_shared<HorizontalFlipOperation>(); }
|
||||||
|
|
||||||
// HwcToChw Transform Operation.
|
// HwcToChw Transform Operation.
|
||||||
HWC2CHW::HWC2CHW() {}
|
HWC2CHW::HWC2CHW() = default;
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> HWC2CHW::Parse() { return std::make_shared<HwcToChwOperation>(); }
|
std::shared_ptr<TensorOperation> HWC2CHW::Parse() { return std::make_shared<HwcToChwOperation>(); }
|
||||||
|
|
||||||
// Invert Transform Operation.
|
// Invert Transform Operation.
|
||||||
Invert::Invert() {}
|
Invert::Invert() = default;
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> Invert::Parse() { return std::make_shared<InvertOperation>(); }
|
std::shared_ptr<TensorOperation> Invert::Parse() { return std::make_shared<InvertOperation>(); }
|
||||||
|
|
||||||
|
@ -419,7 +417,8 @@ struct Normalize::Data {
|
||||||
std::vector<float> std_;
|
std::vector<float> std_;
|
||||||
};
|
};
|
||||||
|
|
||||||
Normalize::Normalize(std::vector<float> mean, std::vector<float> std) : data_(std::make_shared<Data>(mean, std)) {}
|
Normalize::Normalize(const std::vector<float> &mean, const std::vector<float> &std)
|
||||||
|
: data_(std::make_shared<Data>(mean, std)) {}
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> Normalize::Parse() {
|
std::shared_ptr<TensorOperation> Normalize::Parse() {
|
||||||
return std::make_shared<NormalizeOperation>(data_->mean_, data_->std_);
|
return std::make_shared<NormalizeOperation>(data_->mean_, data_->std_);
|
||||||
|
@ -464,7 +463,7 @@ struct Pad::Data {
|
||||||
BorderType padding_mode_;
|
BorderType padding_mode_;
|
||||||
};
|
};
|
||||||
|
|
||||||
Pad::Pad(std::vector<int32_t> padding, std::vector<uint8_t> fill_value, BorderType padding_mode)
|
Pad::Pad(const std::vector<int32_t> &padding, const std::vector<uint8_t> &fill_value, BorderType padding_mode)
|
||||||
: data_(std::make_shared<Data>(padding, fill_value, padding_mode)) {}
|
: data_(std::make_shared<Data>(padding, fill_value, padding_mode)) {}
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> Pad::Parse() {
|
std::shared_ptr<TensorOperation> Pad::Parse() {
|
||||||
|
@ -524,7 +523,7 @@ struct RandomAutoContrast::Data {
|
||||||
float probability_;
|
float probability_;
|
||||||
};
|
};
|
||||||
|
|
||||||
RandomAutoContrast::RandomAutoContrast(float cutoff, std::vector<uint32_t> ignore, float prob)
|
RandomAutoContrast::RandomAutoContrast(float cutoff, const std::vector<uint32_t> &ignore, float prob)
|
||||||
: data_(std::make_shared<Data>(cutoff, ignore, prob)) {}
|
: data_(std::make_shared<Data>(cutoff, ignore, prob)) {}
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> RandomAutoContrast::Parse() {
|
std::shared_ptr<TensorOperation> RandomAutoContrast::Parse() {
|
||||||
|
@ -555,8 +554,8 @@ struct RandomColorAdjust::Data {
|
||||||
std::vector<float> hue_;
|
std::vector<float> hue_;
|
||||||
};
|
};
|
||||||
|
|
||||||
RandomColorAdjust::RandomColorAdjust(std::vector<float> brightness, std::vector<float> contrast,
|
RandomColorAdjust::RandomColorAdjust(const std::vector<float> &brightness, const std::vector<float> &contrast,
|
||||||
std::vector<float> saturation, std::vector<float> hue)
|
const std::vector<float> &saturation, const std::vector<float> &hue)
|
||||||
: data_(std::make_shared<Data>(brightness, contrast, saturation, hue)) {}
|
: data_(std::make_shared<Data>(brightness, contrast, saturation, hue)) {}
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> RandomColorAdjust::Parse() {
|
std::shared_ptr<TensorOperation> RandomColorAdjust::Parse() {
|
||||||
|
@ -580,8 +579,8 @@ struct RandomCrop::Data {
|
||||||
BorderType padding_mode_;
|
BorderType padding_mode_;
|
||||||
};
|
};
|
||||||
|
|
||||||
RandomCrop::RandomCrop(std::vector<int32_t> size, std::vector<int32_t> padding, bool pad_if_needed,
|
RandomCrop::RandomCrop(const std::vector<int32_t> &size, const std::vector<int32_t> &padding, bool pad_if_needed,
|
||||||
std::vector<uint8_t> fill_value, BorderType padding_mode)
|
const std::vector<uint8_t> &fill_value, BorderType padding_mode)
|
||||||
: data_(std::make_shared<Data>(size, padding, pad_if_needed, fill_value, padding_mode)) {}
|
: data_(std::make_shared<Data>(size, padding, pad_if_needed, fill_value, padding_mode)) {}
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> RandomCrop::Parse() {
|
std::shared_ptr<TensorOperation> RandomCrop::Parse() {
|
||||||
|
@ -601,8 +600,8 @@ struct RandomCropDecodeResize::Data {
|
||||||
int32_t max_attempts_;
|
int32_t max_attempts_;
|
||||||
};
|
};
|
||||||
|
|
||||||
RandomCropDecodeResize::RandomCropDecodeResize(std::vector<int32_t> size, std::vector<float> scale,
|
RandomCropDecodeResize::RandomCropDecodeResize(const std::vector<int32_t> &size, const std::vector<float> &scale,
|
||||||
std::vector<float> ratio, InterpolationMode interpolation,
|
const std::vector<float> &ratio, InterpolationMode interpolation,
|
||||||
int32_t max_attempts)
|
int32_t max_attempts)
|
||||||
: data_(std::make_shared<Data>(size, scale, ratio, interpolation, max_attempts)) {}
|
: data_(std::make_shared<Data>(size, scale, ratio, interpolation, max_attempts)) {}
|
||||||
|
|
||||||
|
@ -627,8 +626,9 @@ struct RandomCropWithBBox::Data {
|
||||||
BorderType padding_mode_;
|
BorderType padding_mode_;
|
||||||
};
|
};
|
||||||
|
|
||||||
RandomCropWithBBox::RandomCropWithBBox(std::vector<int32_t> size, std::vector<int32_t> padding, bool pad_if_needed,
|
RandomCropWithBBox::RandomCropWithBBox(const std::vector<int32_t> &size, const std::vector<int32_t> &padding,
|
||||||
std::vector<uint8_t> fill_value, BorderType padding_mode)
|
bool pad_if_needed, const std::vector<uint8_t> &fill_value,
|
||||||
|
BorderType padding_mode)
|
||||||
: data_(std::make_shared<Data>(size, padding, pad_if_needed, fill_value, padding_mode)) {}
|
: data_(std::make_shared<Data>(size, padding, pad_if_needed, fill_value, padding_mode)) {}
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> RandomCropWithBBox::Parse() {
|
std::shared_ptr<TensorOperation> RandomCropWithBBox::Parse() {
|
||||||
|
@ -714,7 +714,7 @@ struct RandomResize::Data {
|
||||||
std::vector<int32_t> size_;
|
std::vector<int32_t> size_;
|
||||||
};
|
};
|
||||||
|
|
||||||
RandomResize::RandomResize(std::vector<int32_t> size) : data_(std::make_shared<Data>(size)) {}
|
RandomResize::RandomResize(const std::vector<int32_t> &size) : data_(std::make_shared<Data>(size)) {}
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> RandomResize::Parse() { return std::make_shared<RandomResizeOperation>(data_->size_); }
|
std::shared_ptr<TensorOperation> RandomResize::Parse() { return std::make_shared<RandomResizeOperation>(data_->size_); }
|
||||||
|
|
||||||
|
@ -724,7 +724,7 @@ struct RandomResizeWithBBox::Data {
|
||||||
std::vector<int32_t> size_;
|
std::vector<int32_t> size_;
|
||||||
};
|
};
|
||||||
|
|
||||||
RandomResizeWithBBox::RandomResizeWithBBox(std::vector<int32_t> size) : data_(std::make_shared<Data>(size)) {}
|
RandomResizeWithBBox::RandomResizeWithBBox(const std::vector<int32_t> &size) : data_(std::make_shared<Data>(size)) {}
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> RandomResizeWithBBox::Parse() {
|
std::shared_ptr<TensorOperation> RandomResizeWithBBox::Parse() {
|
||||||
return std::make_shared<RandomResizeWithBBoxOperation>(data_->size_);
|
return std::make_shared<RandomResizeWithBBoxOperation>(data_->size_);
|
||||||
|
@ -742,8 +742,9 @@ struct RandomResizedCrop::Data {
|
||||||
int32_t max_attempts_;
|
int32_t max_attempts_;
|
||||||
};
|
};
|
||||||
|
|
||||||
RandomResizedCrop::RandomResizedCrop(std::vector<int32_t> size, std::vector<float> scale, std::vector<float> ratio,
|
RandomResizedCrop::RandomResizedCrop(const std::vector<int32_t> &size, const std::vector<float> &scale,
|
||||||
InterpolationMode interpolation, int32_t max_attempts)
|
const std::vector<float> &ratio, InterpolationMode interpolation,
|
||||||
|
int32_t max_attempts)
|
||||||
: data_(std::make_shared<Data>(size, scale, ratio, interpolation, max_attempts)) {}
|
: data_(std::make_shared<Data>(size, scale, ratio, interpolation, max_attempts)) {}
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> RandomResizedCrop::Parse() {
|
std::shared_ptr<TensorOperation> RandomResizedCrop::Parse() {
|
||||||
|
@ -763,8 +764,8 @@ struct RandomResizedCropWithBBox::Data {
|
||||||
int32_t max_attempts_;
|
int32_t max_attempts_;
|
||||||
};
|
};
|
||||||
|
|
||||||
RandomResizedCropWithBBox::RandomResizedCropWithBBox(std::vector<int32_t> size, std::vector<float> scale,
|
RandomResizedCropWithBBox::RandomResizedCropWithBBox(const std::vector<int32_t> &size, const std::vector<float> &scale,
|
||||||
std::vector<float> ratio, InterpolationMode interpolation,
|
const std::vector<float> &ratio, InterpolationMode interpolation,
|
||||||
int32_t max_attempts)
|
int32_t max_attempts)
|
||||||
: data_(std::make_shared<Data>(size, scale, ratio, interpolation, max_attempts)) {}
|
: data_(std::make_shared<Data>(size, scale, ratio, interpolation, max_attempts)) {}
|
||||||
|
|
||||||
|
@ -785,8 +786,8 @@ struct RandomRotation::Data {
|
||||||
std::vector<uint8_t> fill_value_;
|
std::vector<uint8_t> fill_value_;
|
||||||
};
|
};
|
||||||
|
|
||||||
RandomRotation::RandomRotation(std::vector<float> degrees, InterpolationMode resample, bool expand,
|
RandomRotation::RandomRotation(const std::vector<float> °rees, InterpolationMode resample, bool expand,
|
||||||
std::vector<float> center, std::vector<uint8_t> fill_value)
|
const std::vector<float> ¢er, const std::vector<uint8_t> &fill_value)
|
||||||
: data_(std::make_shared<Data>(degrees, resample, expand, center, fill_value)) {}
|
: data_(std::make_shared<Data>(degrees, resample, expand, center, fill_value)) {}
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> RandomRotation::Parse() {
|
std::shared_ptr<TensorOperation> RandomRotation::Parse() {
|
||||||
|
@ -857,7 +858,7 @@ struct RandomSharpness::Data {
|
||||||
std::vector<float> degrees_;
|
std::vector<float> degrees_;
|
||||||
};
|
};
|
||||||
|
|
||||||
RandomSharpness::RandomSharpness(std::vector<float> degrees) : data_(std::make_shared<Data>(degrees)) {}
|
RandomSharpness::RandomSharpness(const std::vector<float> °rees) : data_(std::make_shared<Data>(degrees)) {}
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> RandomSharpness::Parse() {
|
std::shared_ptr<TensorOperation> RandomSharpness::Parse() {
|
||||||
return std::make_shared<RandomSharpnessOperation>(data_->degrees_);
|
return std::make_shared<RandomSharpnessOperation>(data_->degrees_);
|
||||||
|
@ -869,7 +870,7 @@ struct RandomSolarize::Data {
|
||||||
std::vector<uint8_t> threshold_;
|
std::vector<uint8_t> threshold_;
|
||||||
};
|
};
|
||||||
|
|
||||||
RandomSolarize::RandomSolarize(std::vector<uint8_t> threshold) : data_(std::make_shared<Data>(threshold)) {}
|
RandomSolarize::RandomSolarize(const std::vector<uint8_t> &threshold) : data_(std::make_shared<Data>(threshold)) {}
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> RandomSolarize::Parse() {
|
std::shared_ptr<TensorOperation> RandomSolarize::Parse() {
|
||||||
return std::make_shared<RandomSolarizeOperation>(data_->threshold_);
|
return std::make_shared<RandomSolarizeOperation>(data_->threshold_);
|
||||||
|
@ -921,7 +922,7 @@ struct Resize::Data {
|
||||||
InterpolationMode interpolation_;
|
InterpolationMode interpolation_;
|
||||||
};
|
};
|
||||||
|
|
||||||
Resize::Resize(std::vector<int32_t> size, InterpolationMode interpolation)
|
Resize::Resize(const std::vector<int32_t> &size, InterpolationMode interpolation)
|
||||||
: data_(std::make_shared<Data>(size, interpolation)) {}
|
: data_(std::make_shared<Data>(size, interpolation)) {}
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> Resize::Parse() {
|
std::shared_ptr<TensorOperation> Resize::Parse() {
|
||||||
|
@ -960,6 +961,40 @@ std::shared_ptr<TensorOperation> ResizePreserveAR::Parse() {
|
||||||
return std::make_shared<ResizePreserveAROperation>(data_->height_, data_->width_, data_->img_orientation_);
|
return std::make_shared<ResizePreserveAROperation>(data_->height_, data_->width_, data_->img_orientation_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifndef ENABLE_ANDROID
|
||||||
|
// ResizeWithBBox Transform Operation.
|
||||||
|
struct ResizeWithBBox::Data {
|
||||||
|
Data(const std::vector<int32_t> &size, InterpolationMode interpolation)
|
||||||
|
: size_(size), interpolation_(interpolation) {}
|
||||||
|
std::vector<int32_t> size_;
|
||||||
|
InterpolationMode interpolation_;
|
||||||
|
};
|
||||||
|
|
||||||
|
ResizeWithBBox::ResizeWithBBox(const std::vector<int32_t> &size, InterpolationMode interpolation)
|
||||||
|
: data_(std::make_shared<Data>(size, interpolation)) {}
|
||||||
|
|
||||||
|
std::shared_ptr<TensorOperation> ResizeWithBBox::Parse() {
|
||||||
|
return std::make_shared<ResizeWithBBoxOperation>(data_->size_, data_->interpolation_);
|
||||||
|
}
|
||||||
|
#endif // not ENABLE_ANDROID
|
||||||
|
|
||||||
|
// RGB2BGR Transform Operation.
|
||||||
|
std::shared_ptr<TensorOperation> RGB2BGR::Parse() { return std::make_shared<RgbToBgrOperation>(); }
|
||||||
|
|
||||||
|
// RGB2GRAY Transform Operation.
|
||||||
|
std::shared_ptr<TensorOperation> RGB2GRAY::Parse() { return std::make_shared<RgbToGrayOperation>(); }
|
||||||
|
|
||||||
|
#ifndef ENABLE_ANDROID
|
||||||
|
// RgbaToBgr Transform Operation.
|
||||||
|
RGBA2BGR::RGBA2BGR() = default;
|
||||||
|
|
||||||
|
std::shared_ptr<TensorOperation> RGBA2BGR::Parse() { return std::make_shared<RgbaToBgrOperation>(); }
|
||||||
|
|
||||||
|
// RgbaToRgb Transform Operation.
|
||||||
|
RGBA2RGB::RGBA2RGB() = default;
|
||||||
|
|
||||||
|
std::shared_ptr<TensorOperation> RGBA2RGB::Parse() { return std::make_shared<RgbaToRgbOperation>(); }
|
||||||
|
|
||||||
// Rotate Transform Operation.
|
// Rotate Transform Operation.
|
||||||
struct Rotate::Data {
|
struct Rotate::Data {
|
||||||
Data(const float °rees, InterpolationMode resample, bool expand, const std::vector<float> ¢er,
|
Data(const float °rees, InterpolationMode resample, bool expand, const std::vector<float> ¢er,
|
||||||
|
@ -977,8 +1012,8 @@ struct Rotate::Data {
|
||||||
|
|
||||||
Rotate::Rotate(FixRotationAngle angle_id) : data_(std::make_shared<Data>(angle_id)) {}
|
Rotate::Rotate(FixRotationAngle angle_id) : data_(std::make_shared<Data>(angle_id)) {}
|
||||||
|
|
||||||
Rotate::Rotate(float degrees, InterpolationMode resample, bool expand, std::vector<float> center,
|
Rotate::Rotate(float degrees, InterpolationMode resample, bool expand, const std::vector<float> ¢er,
|
||||||
std::vector<uint8_t> fill_value)
|
const std::vector<uint8_t> &fill_value)
|
||||||
: data_(std::make_shared<Data>(degrees, resample, expand, center, fill_value)) {}
|
: data_(std::make_shared<Data>(degrees, resample, expand, center, fill_value)) {}
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> Rotate::Parse() {
|
std::shared_ptr<TensorOperation> Rotate::Parse() {
|
||||||
|
@ -997,40 +1032,6 @@ std::shared_ptr<TensorOperation> Rotate::Parse() {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifndef ENABLE_ANDROID
|
|
||||||
// ResizeWithBBox Transform Operation.
|
|
||||||
struct ResizeWithBBox::Data {
|
|
||||||
Data(const std::vector<int32_t> &size, InterpolationMode interpolation)
|
|
||||||
: size_(size), interpolation_(interpolation) {}
|
|
||||||
std::vector<int32_t> size_;
|
|
||||||
InterpolationMode interpolation_;
|
|
||||||
};
|
|
||||||
|
|
||||||
ResizeWithBBox::ResizeWithBBox(std::vector<int32_t> size, InterpolationMode interpolation)
|
|
||||||
: data_(std::make_shared<Data>(size, interpolation)) {}
|
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> ResizeWithBBox::Parse() {
|
|
||||||
return std::make_shared<ResizeWithBBoxOperation>(data_->size_, data_->interpolation_);
|
|
||||||
}
|
|
||||||
#endif // not ENABLE_ANDROID
|
|
||||||
|
|
||||||
// RGB2BGR Transform Operation.
|
|
||||||
std::shared_ptr<TensorOperation> RGB2BGR::Parse() { return std::make_shared<RgbToBgrOperation>(); }
|
|
||||||
|
|
||||||
// RGB2GRAY Transform Operation.
|
|
||||||
std::shared_ptr<TensorOperation> RGB2GRAY::Parse() { return std::make_shared<RgbToGrayOperation>(); }
|
|
||||||
|
|
||||||
#ifndef ENABLE_ANDROID
|
|
||||||
// RgbaToBgr Transform Operation.
|
|
||||||
RGBA2BGR::RGBA2BGR() {}
|
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> RGBA2BGR::Parse() { return std::make_shared<RgbaToBgrOperation>(); }
|
|
||||||
|
|
||||||
// RgbaToRgb Transform Operation.
|
|
||||||
RGBA2RGB::RGBA2RGB() {}
|
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> RGBA2RGB::Parse() { return std::make_shared<RgbaToRgbOperation>(); }
|
|
||||||
|
|
||||||
// SlicePatches Transform Operation.
|
// SlicePatches Transform Operation.
|
||||||
struct SlicePatches::Data {
|
struct SlicePatches::Data {
|
||||||
Data(int32_t num_height, int32_t num_width, SliceMode slice_mode, uint8_t fill_value)
|
Data(int32_t num_height, int32_t num_width, SliceMode slice_mode, uint8_t fill_value)
|
||||||
|
@ -1060,9 +1061,10 @@ struct SoftDvppDecodeRandomCropResizeJpeg::Data {
|
||||||
int32_t max_attempts_;
|
int32_t max_attempts_;
|
||||||
};
|
};
|
||||||
|
|
||||||
SoftDvppDecodeRandomCropResizeJpeg::SoftDvppDecodeRandomCropResizeJpeg(std::vector<int32_t> size,
|
SoftDvppDecodeRandomCropResizeJpeg::SoftDvppDecodeRandomCropResizeJpeg(const std::vector<int32_t> &size,
|
||||||
std::vector<float> scale,
|
const std::vector<float> &scale,
|
||||||
std::vector<float> ratio, int32_t max_attempts)
|
const std::vector<float> &ratio,
|
||||||
|
int32_t max_attempts)
|
||||||
: data_(std::make_shared<Data>(size, scale, ratio, max_attempts)) {}
|
: data_(std::make_shared<Data>(size, scale, ratio, max_attempts)) {}
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> SoftDvppDecodeRandomCropResizeJpeg::Parse() {
|
std::shared_ptr<TensorOperation> SoftDvppDecodeRandomCropResizeJpeg::Parse() {
|
||||||
|
@ -1076,14 +1078,15 @@ struct SoftDvppDecodeResizeJpeg::Data {
|
||||||
std::vector<int32_t> size_;
|
std::vector<int32_t> size_;
|
||||||
};
|
};
|
||||||
|
|
||||||
SoftDvppDecodeResizeJpeg::SoftDvppDecodeResizeJpeg(std::vector<int32_t> size) : data_(std::make_shared<Data>(size)) {}
|
SoftDvppDecodeResizeJpeg::SoftDvppDecodeResizeJpeg(const std::vector<int32_t> &size)
|
||||||
|
: data_(std::make_shared<Data>(size)) {}
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> SoftDvppDecodeResizeJpeg::Parse() {
|
std::shared_ptr<TensorOperation> SoftDvppDecodeResizeJpeg::Parse() {
|
||||||
return std::make_shared<SoftDvppDecodeResizeJpegOperation>(data_->size_);
|
return std::make_shared<SoftDvppDecodeResizeJpegOperation>(data_->size_);
|
||||||
}
|
}
|
||||||
|
|
||||||
// SwapRedBlue Transform Operation.
|
// SwapRedBlue Transform Operation.
|
||||||
SwapRedBlue::SwapRedBlue() {}
|
SwapRedBlue::SwapRedBlue() = default;
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> SwapRedBlue::Parse() { return std::make_shared<SwapRedBlueOperation>(); }
|
std::shared_ptr<TensorOperation> SwapRedBlue::Parse() { return std::make_shared<SwapRedBlueOperation>(); }
|
||||||
|
|
||||||
|
@ -1104,7 +1107,7 @@ UniformAugment::UniformAugment(const std::vector<TensorTransform *> &transforms,
|
||||||
UniformAugment::UniformAugment(const std::vector<std::shared_ptr<TensorTransform>> &transforms, int32_t num_ops)
|
UniformAugment::UniformAugment(const std::vector<std::shared_ptr<TensorTransform>> &transforms, int32_t num_ops)
|
||||||
: data_(std::make_shared<Data>()) {
|
: data_(std::make_shared<Data>()) {
|
||||||
(void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
|
(void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
|
||||||
[](const std::shared_ptr<TensorTransform> op) -> std::shared_ptr<TensorOperation> {
|
[](const std::shared_ptr<TensorTransform> &op) -> std::shared_ptr<TensorOperation> {
|
||||||
return op ? op->Parse() : nullptr;
|
return op ? op->Parse() : nullptr;
|
||||||
});
|
});
|
||||||
data_->num_ops_ = num_ops;
|
data_->num_ops_ = num_ops;
|
||||||
|
@ -1122,11 +1125,10 @@ std::shared_ptr<TensorOperation> UniformAugment::Parse() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// VerticalFlip Transform Operation.
|
// VerticalFlip Transform Operation.
|
||||||
VerticalFlip::VerticalFlip() {}
|
VerticalFlip::VerticalFlip() = default;
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> VerticalFlip::Parse() { return std::make_shared<VerticalFlipOperation>(); }
|
std::shared_ptr<TensorOperation> VerticalFlip::Parse() { return std::make_shared<VerticalFlipOperation>(); }
|
||||||
#endif // not ENABLE_ANDROID
|
#endif // not ENABLE_ANDROID
|
||||||
|
|
||||||
} // namespace vision
|
} // namespace vision
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -34,7 +34,7 @@ constexpr char kAllpassBiquadOperation[] = "AllpassBiquad";
|
||||||
|
|
||||||
class AllpassBiquadOperation : public TensorOperation {
|
class AllpassBiquadOperation : public TensorOperation {
|
||||||
public:
|
public:
|
||||||
explicit AllpassBiquadOperation(int32_t sample_rate, float central_freq, float Q);
|
AllpassBiquadOperation(int32_t sample_rate, float central_freq, float Q);
|
||||||
|
|
||||||
~AllpassBiquadOperation() = default;
|
~AllpassBiquadOperation() = default;
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -28,7 +28,6 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
namespace audio {
|
namespace audio {
|
||||||
|
|
||||||
constexpr char kAmplitudeToDBOperation[] = "AmplitudeToDB";
|
constexpr char kAmplitudeToDBOperation[] = "AmplitudeToDB";
|
||||||
|
|
||||||
class AmplitudeToDBOperation : public TensorOperation {
|
class AmplitudeToDBOperation : public TensorOperation {
|
||||||
|
@ -51,7 +50,6 @@ class AmplitudeToDBOperation : public TensorOperation {
|
||||||
float amin_;
|
float amin_;
|
||||||
float top_db_;
|
float top_db_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace audio
|
} // namespace audio
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -22,7 +22,7 @@ namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
namespace audio {
|
namespace audio {
|
||||||
// AngleOperation
|
// AngleOperation
|
||||||
AngleOperation::AngleOperation() {}
|
AngleOperation::AngleOperation() = default;
|
||||||
|
|
||||||
Status AngleOperation::ValidateParams() { return Status::OK(); }
|
Status AngleOperation::ValidateParams() { return Status::OK(); }
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -30,7 +30,6 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
namespace audio {
|
namespace audio {
|
||||||
|
|
||||||
constexpr char kAngleOperation[] = "Angle";
|
constexpr char kAngleOperation[] = "Angle";
|
||||||
|
|
||||||
class AngleOperation : public TensorOperation {
|
class AngleOperation : public TensorOperation {
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -28,7 +28,7 @@ BandBiquadOperation::BandBiquadOperation(int32_t sample_rate, float central_freq
|
||||||
|
|
||||||
Status BandBiquadOperation::ValidateParams() {
|
Status BandBiquadOperation::ValidateParams() {
|
||||||
RETURN_IF_NOT_OK(ValidateScalar("BandBiquad", "Q", Q_, {0, 1.0}, true, false));
|
RETURN_IF_NOT_OK(ValidateScalar("BandBiquad", "Q", Q_, {0, 1.0}, true, false));
|
||||||
RETURN_IF_NOT_OK(ValidateScalarNotZero("BandBIquad", "sample_rate", sample_rate_));
|
RETURN_IF_NOT_OK(ValidateScalarNotZero("BandBiquad", "sample_rate", sample_rate_));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -30,12 +30,11 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
namespace audio {
|
namespace audio {
|
||||||
|
|
||||||
constexpr char kBandBiquadOperation[] = "BandBiquad";
|
constexpr char kBandBiquadOperation[] = "BandBiquad";
|
||||||
|
|
||||||
class BandBiquadOperation : public TensorOperation {
|
class BandBiquadOperation : public TensorOperation {
|
||||||
public:
|
public:
|
||||||
explicit BandBiquadOperation(int32_t sample_rate, float central_freq, float Q, bool noise);
|
BandBiquadOperation(int32_t sample_rate, float central_freq, float Q, bool noise);
|
||||||
|
|
||||||
~BandBiquadOperation() = default;
|
~BandBiquadOperation() = default;
|
||||||
|
|
||||||
|
@ -53,7 +52,6 @@ class BandBiquadOperation : public TensorOperation {
|
||||||
float Q_;
|
float Q_;
|
||||||
bool noise_;
|
bool noise_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace audio
|
} // namespace audio
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -30,12 +30,11 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
namespace audio {
|
namespace audio {
|
||||||
|
|
||||||
constexpr char kBandpassBiquadOperation[] = "BandpassBiquad";
|
constexpr char kBandpassBiquadOperation[] = "BandpassBiquad";
|
||||||
|
|
||||||
class BandpassBiquadOperation : public TensorOperation {
|
class BandpassBiquadOperation : public TensorOperation {
|
||||||
public:
|
public:
|
||||||
explicit BandpassBiquadOperation(int32_t sample_rate, float central_freq, float Q, bool const_skirt_gain);
|
BandpassBiquadOperation(int32_t sample_rate, float central_freq, float Q, bool const_skirt_gain);
|
||||||
|
|
||||||
~BandpassBiquadOperation() = default;
|
~BandpassBiquadOperation() = default;
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -30,12 +30,11 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
namespace audio {
|
namespace audio {
|
||||||
|
|
||||||
constexpr char kBandrejectBiquadOperation[] = "BandrejectBiquad";
|
constexpr char kBandrejectBiquadOperation[] = "BandrejectBiquad";
|
||||||
|
|
||||||
class BandrejectBiquadOperation : public TensorOperation {
|
class BandrejectBiquadOperation : public TensorOperation {
|
||||||
public:
|
public:
|
||||||
explicit BandrejectBiquadOperation(int32_t sample_rate, float central_freq, float Q);
|
BandrejectBiquadOperation(int32_t sample_rate, float central_freq, float Q);
|
||||||
|
|
||||||
~BandrejectBiquadOperation() = default;
|
~BandrejectBiquadOperation() = default;
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -30,12 +30,11 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
namespace audio {
|
namespace audio {
|
||||||
|
|
||||||
constexpr char kBassBiquadOperation[] = "BassBiquad";
|
constexpr char kBassBiquadOperation[] = "BassBiquad";
|
||||||
|
|
||||||
class BassBiquadOperation : public TensorOperation {
|
class BassBiquadOperation : public TensorOperation {
|
||||||
public:
|
public:
|
||||||
explicit BassBiquadOperation(int32_t sample_rate, float gain, float central_freq, float Q);
|
BassBiquadOperation(int32_t sample_rate, float gain, float central_freq, float Q);
|
||||||
|
|
||||||
~BassBiquadOperation() = default;
|
~BassBiquadOperation() = default;
|
||||||
|
|
||||||
|
@ -53,7 +52,6 @@ class BassBiquadOperation : public TensorOperation {
|
||||||
float central_freq_;
|
float central_freq_;
|
||||||
float Q_;
|
float Q_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace audio
|
} // namespace audio
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -29,9 +29,9 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
|
||||||
namespace audio {
|
namespace audio {
|
||||||
constexpr char kBiquadOperation[] = "Biquad";
|
constexpr char kBiquadOperation[] = "Biquad";
|
||||||
|
|
||||||
class BiquadOperation : public TensorOperation {
|
class BiquadOperation : public TensorOperation {
|
||||||
public:
|
public:
|
||||||
BiquadOperation(float b0, float b1, float b2, float a0, float a1, float a2);
|
BiquadOperation(float b0, float b1, float b2, float a0, float a1, float a2);
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -13,6 +13,7 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_COMPLEX_NORM_IR_H_
|
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_COMPLEX_NORM_IR_H_
|
||||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_COMPLEX_NORM_IR_H_
|
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_COMPLEX_NORM_IR_H_
|
||||||
|
|
||||||
|
@ -26,7 +27,6 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
namespace audio {
|
namespace audio {
|
||||||
|
|
||||||
constexpr char kComplexNormOperation[] = "ComplexNorm";
|
constexpr char kComplexNormOperation[] = "ComplexNorm";
|
||||||
|
|
||||||
class ComplexNormOperation : public TensorOperation {
|
class ComplexNormOperation : public TensorOperation {
|
||||||
|
@ -46,7 +46,6 @@ class ComplexNormOperation : public TensorOperation {
|
||||||
private:
|
private:
|
||||||
float power_;
|
float power_;
|
||||||
}; // class ComplexNormOperation
|
}; // class ComplexNormOperation
|
||||||
|
|
||||||
} // namespace audio
|
} // namespace audio
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -28,7 +28,6 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
namespace audio {
|
namespace audio {
|
||||||
|
|
||||||
constexpr char kContrastOperation[] = "Contrast";
|
constexpr char kContrastOperation[] = "Contrast";
|
||||||
|
|
||||||
class ContrastOperation : public TensorOperation {
|
class ContrastOperation : public TensorOperation {
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -28,7 +28,6 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
namespace audio {
|
namespace audio {
|
||||||
|
|
||||||
constexpr char kDCShiftOperation[] = "DCShift";
|
constexpr char kDCShiftOperation[] = "DCShift";
|
||||||
|
|
||||||
class DCShiftOperation : public TensorOperation {
|
class DCShiftOperation : public TensorOperation {
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -28,9 +28,7 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
|
||||||
namespace audio {
|
namespace audio {
|
||||||
|
|
||||||
constexpr char kDeemphBiquadOperation[] = "DeemphBiquad";
|
constexpr char kDeemphBiquadOperation[] = "DeemphBiquad";
|
||||||
|
|
||||||
class DeemphBiquadOperation : public TensorOperation {
|
class DeemphBiquadOperation : public TensorOperation {
|
||||||
|
@ -51,7 +49,6 @@ class DeemphBiquadOperation : public TensorOperation {
|
||||||
int32_t sample_rate_;
|
int32_t sample_rate_;
|
||||||
};
|
};
|
||||||
} // namespace audio
|
} // namespace audio
|
||||||
|
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_DEEMPH_BIQUAD_IR_H_
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_DEEMPH_BIQUAD_IR_H_
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -17,7 +17,6 @@
|
||||||
#include "minddata/dataset/audio/ir/kernels/equalizer_biquad_ir.h"
|
#include "minddata/dataset/audio/ir/kernels/equalizer_biquad_ir.h"
|
||||||
|
|
||||||
#include "minddata/dataset/audio/ir/validators.h"
|
#include "minddata/dataset/audio/ir/validators.h"
|
||||||
#include "minddata/dataset/audio/kernels/audio_utils.h"
|
|
||||||
#include "minddata/dataset/audio/kernels/equalizer_biquad_op.h"
|
#include "minddata/dataset/audio/kernels/equalizer_biquad_op.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -35,7 +35,7 @@ std::shared_ptr<TensorOp> FadeOperation::Build() {
|
||||||
return tensor_op;
|
return tensor_op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status FadeOperation::to_json(nlohmann::json *const out_json) {
|
Status FadeOperation::to_json(nlohmann::json *out_json) {
|
||||||
nlohmann::json args;
|
nlohmann::json args;
|
||||||
args["fade_in_len"] = fade_in_len_;
|
args["fade_in_len"] = fade_in_len_;
|
||||||
args["fade_out_len"] = fade_out_len_;
|
args["fade_out_len"] = fade_out_len_;
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -27,12 +27,11 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
namespace audio {
|
namespace audio {
|
||||||
|
|
||||||
constexpr char kFadeOperation[] = "Fade";
|
constexpr char kFadeOperation[] = "Fade";
|
||||||
|
|
||||||
class FadeOperation : public TensorOperation {
|
class FadeOperation : public TensorOperation {
|
||||||
public:
|
public:
|
||||||
explicit FadeOperation(int32_t fade_in_len, int32_t fade_out_len, FadeShape fade_shape);
|
FadeOperation(int32_t fade_in_len, int32_t fade_out_len, FadeShape fade_shape);
|
||||||
|
|
||||||
~FadeOperation() = default;
|
~FadeOperation() = default;
|
||||||
|
|
||||||
|
@ -45,7 +44,7 @@ class FadeOperation : public TensorOperation {
|
||||||
/// \brief Get the arguments of node
|
/// \brief Get the arguments of node
|
||||||
/// \param[out] out_json JSON string of all attributes
|
/// \param[out] out_json JSON string of all attributes
|
||||||
/// \return Status of the function
|
/// \return Status of the function
|
||||||
Status to_json(nlohmann::json *const out_json) override;
|
Status to_json(nlohmann::json *out_json) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int32_t fade_in_len_;
|
int32_t fade_in_len_;
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -15,8 +15,9 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "minddata/dataset/audio/ir/kernels/frequency_masking_ir.h"
|
#include "minddata/dataset/audio/ir/kernels/frequency_masking_ir.h"
|
||||||
#include "minddata/dataset/audio/kernels/frequency_masking_op.h"
|
|
||||||
#include "minddata/dataset/audio/ir/validators.h"
|
#include "minddata/dataset/audio/ir/validators.h"
|
||||||
|
#include "minddata/dataset/audio/kernels/frequency_masking_op.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -26,7 +26,6 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
namespace audio {
|
namespace audio {
|
||||||
|
|
||||||
constexpr char kFrequencyMaskingOperation[] = "FrequencyMasking";
|
constexpr char kFrequencyMaskingOperation[] = "FrequencyMasking";
|
||||||
|
|
||||||
class FrequencyMaskingOperation : public TensorOperation {
|
class FrequencyMaskingOperation : public TensorOperation {
|
||||||
|
@ -49,7 +48,6 @@ class FrequencyMaskingOperation : public TensorOperation {
|
||||||
bool iid_masks_;
|
bool iid_masks_;
|
||||||
float mask_value_;
|
float mask_value_;
|
||||||
}; // class FrequencyMaskingOperation
|
}; // class FrequencyMaskingOperation
|
||||||
|
|
||||||
} // namespace audio
|
} // namespace audio
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -29,9 +29,7 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
|
||||||
namespace audio {
|
namespace audio {
|
||||||
|
|
||||||
constexpr char kHighpassBiquadOperation[] = "HighpassBiquad";
|
constexpr char kHighpassBiquadOperation[] = "HighpassBiquad";
|
||||||
|
|
||||||
class HighpassBiquadOperation : public TensorOperation {
|
class HighpassBiquadOperation : public TensorOperation {
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -28,7 +28,6 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
|
||||||
namespace audio {
|
namespace audio {
|
||||||
// Char arrays storing name of corresponding classes (in alphabetical order)
|
// Char arrays storing name of corresponding classes (in alphabetical order)
|
||||||
constexpr char kLFilterOperation[] = "LFilter";
|
constexpr char kLFilterOperation[] = "LFilter";
|
||||||
|
@ -53,7 +52,6 @@ class LFilterOperation : public TensorOperation {
|
||||||
bool clamp_;
|
bool clamp_;
|
||||||
};
|
};
|
||||||
} // namespace audio
|
} // namespace audio
|
||||||
|
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_LFILTER_IR_H_
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_LFILTER_IR_H_
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -52,7 +52,6 @@ class LowpassBiquadOperation : public TensorOperation {
|
||||||
float cutoff_freq_;
|
float cutoff_freq_;
|
||||||
float Q_;
|
float Q_;
|
||||||
}; // class LowpassBiquad
|
}; // class LowpassBiquad
|
||||||
|
|
||||||
} // namespace audio
|
} // namespace audio
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -27,9 +27,7 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
|
||||||
namespace audio {
|
namespace audio {
|
||||||
|
|
||||||
constexpr char kMagphaseOperation[] = "Magphase";
|
constexpr char kMagphaseOperation[] = "Magphase";
|
||||||
|
|
||||||
class MagphaseOperation : public TensorOperation {
|
class MagphaseOperation : public TensorOperation {
|
||||||
|
@ -51,7 +49,6 @@ class MagphaseOperation : public TensorOperation {
|
||||||
private:
|
private:
|
||||||
float power_;
|
float power_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace audio
|
} // namespace audio
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -13,6 +13,7 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "minddata/dataset/audio/ir/kernels/mu_law_decoding_ir.h"
|
#include "minddata/dataset/audio/ir/kernels/mu_law_decoding_ir.h"
|
||||||
|
|
||||||
#include "minddata/dataset/audio/ir/validators.h"
|
#include "minddata/dataset/audio/ir/validators.h"
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -26,7 +26,6 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
namespace audio {
|
namespace audio {
|
||||||
|
|
||||||
constexpr char kMuLawDecodingOperation[] = "MuLawDecoding";
|
constexpr char kMuLawDecodingOperation[] = "MuLawDecoding";
|
||||||
|
|
||||||
class MuLawDecodingOperation : public TensorOperation {
|
class MuLawDecodingOperation : public TensorOperation {
|
||||||
|
@ -46,9 +45,7 @@ class MuLawDecodingOperation : public TensorOperation {
|
||||||
private:
|
private:
|
||||||
int32_t quantization_channels_;
|
int32_t quantization_channels_;
|
||||||
}; // class MuLawDecodingOperation
|
}; // class MuLawDecodingOperation
|
||||||
|
|
||||||
} // namespace audio
|
} // namespace audio
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MU_LAW_DECODING_IR_H_
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MU_LAW_DECODING_IR_H_
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -26,7 +26,6 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
namespace audio {
|
namespace audio {
|
||||||
|
|
||||||
constexpr char kMuLawEncodingOperation[] = "MuLawEncoding";
|
constexpr char kMuLawEncodingOperation[] = "MuLawEncoding";
|
||||||
|
|
||||||
class MuLawEncodingOperation : public TensorOperation {
|
class MuLawEncodingOperation : public TensorOperation {
|
||||||
|
@ -49,5 +48,4 @@ class MuLawEncodingOperation : public TensorOperation {
|
||||||
} // namespace audio
|
} // namespace audio
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MU_LAW_ENCODING_IR_H_
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MU_LAW_ENCODING_IR_H_
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -30,7 +30,6 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
namespace audio {
|
namespace audio {
|
||||||
|
|
||||||
constexpr char kRiaaBiquadOperation[] = "RiaaBiquad";
|
constexpr char kRiaaBiquadOperation[] = "RiaaBiquad";
|
||||||
|
|
||||||
class RiaaBiquadOperation : public TensorOperation {
|
class RiaaBiquadOperation : public TensorOperation {
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -26,7 +26,6 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
namespace audio {
|
namespace audio {
|
||||||
|
|
||||||
constexpr char kTimeMaskingOperation[] = "TimeMasking";
|
constexpr char kTimeMaskingOperation[] = "TimeMasking";
|
||||||
|
|
||||||
class TimeMaskingOperation : public TensorOperation {
|
class TimeMaskingOperation : public TensorOperation {
|
||||||
|
@ -49,7 +48,6 @@ class TimeMaskingOperation : public TensorOperation {
|
||||||
int32_t mask_start_;
|
int32_t mask_start_;
|
||||||
float mask_value_;
|
float mask_value_;
|
||||||
}; // class TimeMaskingOperation
|
}; // class TimeMaskingOperation
|
||||||
|
|
||||||
} // namespace audio
|
} // namespace audio
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -26,7 +26,6 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
namespace audio {
|
namespace audio {
|
||||||
|
|
||||||
constexpr char kTimeStretchOperation[] = "TimeStretch";
|
constexpr char kTimeStretchOperation[] = "TimeStretch";
|
||||||
|
|
||||||
class TimeStretchOperation : public TensorOperation {
|
class TimeStretchOperation : public TensorOperation {
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -27,7 +27,6 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
namespace audio {
|
namespace audio {
|
||||||
|
|
||||||
constexpr char kVolOperation[] = "Vol";
|
constexpr char kVolOperation[] = "Vol";
|
||||||
|
|
||||||
class VolOperation : public TensorOperation {
|
class VolOperation : public TensorOperation {
|
||||||
|
@ -48,7 +47,6 @@ class VolOperation : public TensorOperation {
|
||||||
float gain_;
|
float gain_;
|
||||||
GainType gain_type_;
|
GainType gain_type_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace audio
|
} // namespace audio
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -56,7 +56,7 @@ class EnWik9Op : public TextFileOp {
|
||||||
/// \brief DatasetName name getter.
|
/// \brief DatasetName name getter.
|
||||||
/// \param[in] upper A bool to control if you need upper DatasetName.
|
/// \param[in] upper A bool to control if you need upper DatasetName.
|
||||||
/// \return DatasetName of the current Op.
|
/// \return DatasetName of the current Op.
|
||||||
virtual std::string DatasetName(bool upper = false) const { return upper ? "EnWik9" : "enwik9"; }
|
std::string DatasetName(bool upper = false) const override { return upper ? "EnWik9" : "enwik9"; }
|
||||||
|
|
||||||
/// \brief Reads a text file and loads the data into multiple TensorRows.
|
/// \brief Reads a text file and loads the data into multiple TensorRows.
|
||||||
/// \param[in] file The file to read.
|
/// \param[in] file The file to read.
|
||||||
|
@ -70,7 +70,7 @@ class EnWik9Op : public TextFileOp {
|
||||||
/// \brief Count number of rows in each file.
|
/// \brief Count number of rows in each file.
|
||||||
/// \param[in] file Txt file name.
|
/// \param[in] file Txt file name.
|
||||||
/// \return int64_t The total number of rows in file.
|
/// \return int64_t The total number of rows in file.
|
||||||
int64_t CountTotalRows(const std::string &file);
|
int64_t CountTotalRows(const std::string &file) override;
|
||||||
};
|
};
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -29,8 +29,8 @@ namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
// Constructor for EnWik9Node
|
// Constructor for EnWik9Node
|
||||||
EnWik9Node::EnWik9Node(const std::string &dataset_dir, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards,
|
EnWik9Node::EnWik9Node(const std::string &dataset_dir, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards,
|
||||||
int32_t shard_id, std::shared_ptr<DatasetCache> cache)
|
int32_t shard_id, const std::shared_ptr<DatasetCache> &cache)
|
||||||
: NonMappableSourceNode(std::move(cache)),
|
: NonMappableSourceNode(cache),
|
||||||
num_samples_(num_samples),
|
num_samples_(num_samples),
|
||||||
shuffle_(shuffle),
|
shuffle_(shuffle),
|
||||||
num_shards_(num_shards),
|
num_shards_(num_shards),
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -37,7 +37,7 @@ class EnWik9Node : public NonMappableSourceNode {
|
||||||
/// \param[in] shard_id The id of shard.
|
/// \param[in] shard_id The id of shard.
|
||||||
/// \param[in] cache Tensor cache to use.
|
/// \param[in] cache Tensor cache to use.
|
||||||
EnWik9Node(const std::string &dataset_dir, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards,
|
EnWik9Node(const std::string &dataset_dir, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards,
|
||||||
int32_t shard_id, std::shared_ptr<DatasetCache> cache);
|
int32_t shard_id, const std::shared_ptr<DatasetCache> &cache);
|
||||||
|
|
||||||
/// \brief Destructor.
|
/// \brief Destructor.
|
||||||
~EnWik9Node() = default;
|
~EnWik9Node() = default;
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -31,12 +31,10 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
|
||||||
class TensorOperation;
|
class TensorOperation;
|
||||||
|
|
||||||
// Transform operations for performing computer audio.
|
// Transform operations for performing computer audio.
|
||||||
namespace audio {
|
namespace audio {
|
||||||
|
|
||||||
/// \brief Compute the angle of complex tensor input.
|
/// \brief Compute the angle of complex tensor input.
|
||||||
class MS_API Angle final : public TensorTransform {
|
class MS_API Angle final : public TensorTransform {
|
||||||
public:
|
public:
|
||||||
|
@ -51,29 +49,6 @@ class MS_API Angle final : public TensorTransform {
|
||||||
std::shared_ptr<TensorOperation> Parse() override;
|
std::shared_ptr<TensorOperation> Parse() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// \brief Design two-pole band filter.
|
|
||||||
class MS_API BandBiquad final : public TensorTransform {
|
|
||||||
public:
|
|
||||||
/// \brief Constructor.
|
|
||||||
/// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero.
|
|
||||||
/// \param[in] central_freq Central frequency (in Hz).
|
|
||||||
/// \param[in] Q Quality factor, https://en.wikipedia.org/wiki/Q_factor, range: (0, 1] (Default: 0.707).
|
|
||||||
/// \param[in] noise Choose alternate mode for un-pitched audio or mode oriented to pitched audio(Default: False).
|
|
||||||
explicit BandBiquad(int32_t sample_rate, float central_freq, float Q = 0.707, bool noise = false);
|
|
||||||
|
|
||||||
/// \brief Destructor.
|
|
||||||
~BandBiquad() = default;
|
|
||||||
|
|
||||||
protected:
|
|
||||||
/// \brief Function to convert TensorTransform object into a TensorOperation object.
|
|
||||||
/// \return Shared pointer to TensorOperation object.
|
|
||||||
std::shared_ptr<TensorOperation> Parse() override;
|
|
||||||
|
|
||||||
private:
|
|
||||||
struct Data;
|
|
||||||
std::shared_ptr<Data> data_;
|
|
||||||
};
|
|
||||||
|
|
||||||
/// \brief Design two-pole allpass filter. Similar to SoX implementation.
|
/// \brief Design two-pole allpass filter. Similar to SoX implementation.
|
||||||
class MS_API AllpassBiquad final : public TensorTransform {
|
class MS_API AllpassBiquad final : public TensorTransform {
|
||||||
public:
|
public:
|
||||||
|
@ -121,6 +96,29 @@ class MS_API AmplitudeToDB final : public TensorTransform {
|
||||||
std::shared_ptr<Data> data_;
|
std::shared_ptr<Data> data_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// \brief Design two-pole band filter.
|
||||||
|
class MS_API BandBiquad final : public TensorTransform {
|
||||||
|
public:
|
||||||
|
/// \brief Constructor.
|
||||||
|
/// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero.
|
||||||
|
/// \param[in] central_freq Central frequency (in Hz).
|
||||||
|
/// \param[in] Q Quality factor, https://en.wikipedia.org/wiki/Q_factor, range: (0, 1] (Default: 0.707).
|
||||||
|
/// \param[in] noise Choose alternate mode for un-pitched audio or mode oriented to pitched audio(Default: False).
|
||||||
|
explicit BandBiquad(int32_t sample_rate, float central_freq, float Q = 0.707, bool noise = false);
|
||||||
|
|
||||||
|
/// \brief Destructor.
|
||||||
|
~BandBiquad() = default;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
/// \brief Function to convert TensorTransform object into a TensorOperation object.
|
||||||
|
/// \return Shared pointer to TensorOperation object.
|
||||||
|
std::shared_ptr<TensorOperation> Parse() override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
struct Data;
|
||||||
|
std::shared_ptr<Data> data_;
|
||||||
|
};
|
||||||
|
|
||||||
/// \brief Design two-pole band-pass filter.
|
/// \brief Design two-pole band-pass filter.
|
||||||
class MS_API BandpassBiquad final : public TensorTransform {
|
class MS_API BandpassBiquad final : public TensorTransform {
|
||||||
public:
|
public:
|
||||||
|
@ -561,7 +559,7 @@ class MS_API LFilter final : public TensorTransform {
|
||||||
/// Lower delays coefficients are first, e.g. [b0, b1, b2, ...].
|
/// Lower delays coefficients are first, e.g. [b0, b1, b2, ...].
|
||||||
/// Must be same size as a_coeffs (pad with 0's as necessary).
|
/// Must be same size as a_coeffs (pad with 0's as necessary).
|
||||||
/// \param[in] clamp If True, clamp the output signal to be in the range [-1, 1] (Default: True).
|
/// \param[in] clamp If True, clamp the output signal to be in the range [-1, 1] (Default: True).
|
||||||
explicit LFilter(std::vector<float> a_coeffs, std::vector<float> b_coeffs, bool clamp = true);
|
explicit LFilter(const std::vector<float> &a_coeffs, const std::vector<float> &b_coeffs, bool clamp = true);
|
||||||
|
|
||||||
/// \brief Destructor.
|
/// \brief Destructor.
|
||||||
~LFilter() = default;
|
~LFilter() = default;
|
||||||
|
@ -695,8 +693,8 @@ class MS_API Phaser final : public TensorTransform {
|
||||||
/// \param[in] mod_speed Modulation speed in Hz. Allowed range of values is [0.1, 2] (Default=0.5).
|
/// \param[in] mod_speed Modulation speed in Hz. Allowed range of values is [0.1, 2] (Default=0.5).
|
||||||
/// \param[in] sinusoidal If true, use sinusoidal modulation (preferable for multiple instruments).
|
/// \param[in] sinusoidal If true, use sinusoidal modulation (preferable for multiple instruments).
|
||||||
/// If false, use triangular modulation (gives single instruments a sharper phasing effect) (Default=true).
|
/// If false, use triangular modulation (gives single instruments a sharper phasing effect) (Default=true).
|
||||||
Phaser(int32_t sample_rate, float gain_in = 0.4f, float gain_out = 0.74f, float delay_ms = 3.0f, float decay = 0.4f,
|
explicit Phaser(int32_t sample_rate, float gain_in = 0.4f, float gain_out = 0.74f, float delay_ms = 3.0f,
|
||||||
float mod_speed = 0.5f, bool sinusoidal = true);
|
float decay = 0.4f, float mod_speed = 0.5f, bool sinusoidal = true);
|
||||||
|
|
||||||
/// \brief Destructor.
|
/// \brief Destructor.
|
||||||
~Phaser() = default;
|
~Phaser() = default;
|
||||||
|
@ -770,7 +768,7 @@ class MS_API SpectralCentroid : public TensorTransform {
|
||||||
/// \param[in] window Window function that is applied/multiplied to each frame/window,
|
/// \param[in] window Window function that is applied/multiplied to each frame/window,
|
||||||
/// which can be WindowType::kBartlett, WindowType::kBlackman, WindowType::kHamming,
|
/// which can be WindowType::kBartlett, WindowType::kBlackman, WindowType::kHamming,
|
||||||
/// WindowType::kHann or WindowType::kKaiser (Default: WindowType::kHann).
|
/// WindowType::kHann or WindowType::kKaiser (Default: WindowType::kHann).
|
||||||
SpectralCentroid(int sample_rate, int32_t n_fft = 400, int32_t win_length = 0, int32_t hop_length = 0,
|
explicit SpectralCentroid(int32_t sample_rate, int32_t n_fft = 400, int32_t win_length = 0, int32_t hop_length = 0,
|
||||||
int32_t pad = 0, WindowType window = WindowType::kHann);
|
int32_t pad = 0, WindowType window = WindowType::kHann);
|
||||||
|
|
||||||
~SpectralCentroid() = default;
|
~SpectralCentroid() = default;
|
||||||
|
@ -807,9 +805,9 @@ class MS_API Spectrogram : public TensorTransform {
|
||||||
/// \param[in] center Whether to pad waveform on both sides (Default: true).
|
/// \param[in] center Whether to pad waveform on both sides (Default: true).
|
||||||
/// \param[in] pad_mode Controls the padding method used when center is true (Default: BorderType::kReflect).
|
/// \param[in] pad_mode Controls the padding method used when center is true (Default: BorderType::kReflect).
|
||||||
/// \param[in] onesided Controls whether to return half of results to avoid redundancy (Default: true).
|
/// \param[in] onesided Controls whether to return half of results to avoid redundancy (Default: true).
|
||||||
Spectrogram(int32_t n_fft = 400, int32_t win_length = 0, int32_t hop_length = 0, int32_t pad = 0,
|
explicit Spectrogram(int32_t n_fft = 400, int32_t win_length = 0, int32_t hop_length = 0, int32_t pad = 0,
|
||||||
WindowType window = WindowType::kHann, float power = 2.0, bool normalized = false, bool center = true,
|
WindowType window = WindowType::kHann, float power = 2.0, bool normalized = false,
|
||||||
BorderType pad_mode = BorderType::kReflect, bool onesided = true);
|
bool center = true, BorderType pad_mode = BorderType::kReflect, bool onesided = true);
|
||||||
|
|
||||||
/// \brief Destructor.
|
/// \brief Destructor.
|
||||||
~Spectrogram() = default;
|
~Spectrogram() = default;
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -20,14 +20,14 @@
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "include/api/dual_abi_helper.h"
|
#include "include/api/dual_abi_helper.h"
|
||||||
|
#include "include/api/types.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
|
||||||
// Config operations for setting and getting the configuration.
|
// Config operations for setting and getting the configuration.
|
||||||
namespace config {
|
namespace config {
|
||||||
|
|
||||||
/// \brief A function to set the seed to be used in any random generator. This is used to produce deterministic results.
|
/// \brief A function to set the seed to be used in any random generator. This is used to produce deterministic results.
|
||||||
/// \param[in] seed The default seed to be used.
|
/// \param[in] seed The default seed to be used.
|
||||||
/// \return The seed is set successfully or not.
|
/// \return The seed is set successfully or not.
|
||||||
|
@ -155,10 +155,8 @@ bool MS_API load(const std::vector<char> &file);
|
||||||
/// std::string config_file = "/path/to/config/file";
|
/// std::string config_file = "/path/to/config/file";
|
||||||
/// bool rc = config::load(config_file);
|
/// bool rc = config::load(config_file);
|
||||||
/// \endcode
|
/// \endcode
|
||||||
inline bool MS_API load(std::string file) { return load(StringToChar(file)); }
|
inline bool MS_API load(const std::string &file) { return load(StringToChar(file)); }
|
||||||
|
|
||||||
} // namespace config
|
} // namespace config
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_CONFIG_H
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_CONFIG_H
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -36,7 +36,7 @@ namespace dataset {
|
||||||
class MS_API DataHelper {
|
class MS_API DataHelper {
|
||||||
public:
|
public:
|
||||||
/// \brief constructor
|
/// \brief constructor
|
||||||
DataHelper() {}
|
DataHelper() = default;
|
||||||
|
|
||||||
/// \brief Destructor
|
/// \brief Destructor
|
||||||
~DataHelper() = default;
|
~DataHelper() = default;
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -36,22 +36,22 @@ class MS_API Execute {
|
||||||
/// \param[in] op TensorOperation to be applied in Eager mode, it accepts operation in type of shared pointer.
|
/// \param[in] op TensorOperation to be applied in Eager mode, it accepts operation in type of shared pointer.
|
||||||
/// \param[in] device_type Target device environment to perform operation, can be kCPU/kGPU/kAscend310 (default=kCPU).
|
/// \param[in] device_type Target device environment to perform operation, can be kCPU/kGPU/kAscend310 (default=kCPU).
|
||||||
/// \param[in] device_id Target device ID to perform operation, only valid when device_type=kAscend310 (default=0).
|
/// \param[in] device_id Target device ID to perform operation, only valid when device_type=kAscend310 (default=0).
|
||||||
explicit Execute(std::shared_ptr<TensorOperation> op, MapTargetDevice device_type = MapTargetDevice::kCpu,
|
explicit Execute(const std::shared_ptr<TensorOperation> &op, MapTargetDevice device_type = MapTargetDevice::kCpu,
|
||||||
uint32_t device_id = 0);
|
uint32_t device_id = 0);
|
||||||
|
|
||||||
/// \brief Constructor.
|
/// \brief Constructor.
|
||||||
/// \param[in] op TensorTransform to be applied in Eager mode, it accepts operation in type of shared pointer.
|
/// \param[in] op TensorTransform to be applied in Eager mode, it accepts operation in type of shared pointer.
|
||||||
/// \param[in] device_type Target device environment to perform operation, can be kCPU/kGPU/kAscend310 (default=kCPU).
|
/// \param[in] device_type Target device environment to perform operation, can be kCPU/kGPU/kAscend310 (default=kCPU).
|
||||||
/// \param[in] device_id Target device ID to perform operation, only valid when device_type=kAscend310 (default=0).
|
/// \param[in] device_id Target device ID to perform operation, only valid when device_type=kAscend310 (default=0).
|
||||||
explicit Execute(std::shared_ptr<TensorTransform> op, MapTargetDevice device_type = MapTargetDevice::kCpu,
|
explicit Execute(const std::shared_ptr<TensorTransform> &op, MapTargetDevice device_type = MapTargetDevice::kCpu,
|
||||||
uint32_t device_id = 0);
|
uint32_t device_id = 0);
|
||||||
|
|
||||||
/// \brief Constructor.
|
/// \brief Constructor.
|
||||||
/// \param[in] op TensorTransform to be applied in Eager mode, it accepts operation in type of reference.
|
/// \param[in] op TensorTransform to be applied in Eager mode, it accepts operation in type of reference.
|
||||||
/// \param[in] device_type Target device environment to perform operation, can be kCPU/kGPU/kAscend310 (default=kCPU).
|
/// \param[in] device_type Target device environment to perform operation, can be kCPU/kGPU/kAscend310 (default=kCPU).
|
||||||
/// \param[in] device_id Target device ID to perform operation, only valid when device_type=kAscend310 (default=0).
|
/// \param[in] device_id Target device ID to perform operation, only valid when device_type=kAscend310 (default=0).
|
||||||
explicit Execute(std::reference_wrapper<TensorTransform> op, MapTargetDevice device_type = MapTargetDevice::kCpu,
|
explicit Execute(const std::reference_wrapper<TensorTransform> &op,
|
||||||
uint32_t device_id = 0);
|
MapTargetDevice device_type = MapTargetDevice::kCpu, uint32_t device_id = 0);
|
||||||
|
|
||||||
/// \brief Constructor.
|
/// \brief Constructor.
|
||||||
/// \param[in] op TensorTransform to be applied in Eager mode, it accepts operation in type of raw pointer.
|
/// \param[in] op TensorTransform to be applied in Eager mode, it accepts operation in type of raw pointer.
|
||||||
|
@ -64,7 +64,7 @@ class MS_API Execute {
|
||||||
/// in type of shared pointer.
|
/// in type of shared pointer.
|
||||||
/// \param[in] device_type Target device environment to perform operation, can be kCPU/kGPU/kAscend310 (default=kCPU).
|
/// \param[in] device_type Target device environment to perform operation, can be kCPU/kGPU/kAscend310 (default=kCPU).
|
||||||
/// \param[in] device_id Target device ID to perform operation, only valid when device_type=kAscend310 (default=0).
|
/// \param[in] device_id Target device ID to perform operation, only valid when device_type=kAscend310 (default=0).
|
||||||
explicit Execute(std::vector<std::shared_ptr<TensorOperation>> ops,
|
explicit Execute(const std::vector<std::shared_ptr<TensorOperation>> &ops,
|
||||||
MapTargetDevice device_type = MapTargetDevice::kCpu, uint32_t device_id = 0);
|
MapTargetDevice device_type = MapTargetDevice::kCpu, uint32_t device_id = 0);
|
||||||
|
|
||||||
/// \brief Constructor.
|
/// \brief Constructor.
|
||||||
|
@ -72,7 +72,7 @@ class MS_API Execute {
|
||||||
/// in type of shared pointer.
|
/// in type of shared pointer.
|
||||||
/// \param[in] device_type Target device environment to perform operation, can be kCPU/kGPU/kAscend310 (default=kCPU).
|
/// \param[in] device_type Target device environment to perform operation, can be kCPU/kGPU/kAscend310 (default=kCPU).
|
||||||
/// \param[in] device_id Target device ID to perform operation, only valid when device_type=kAscend310 (default=0).
|
/// \param[in] device_id Target device ID to perform operation, only valid when device_type=kAscend310 (default=0).
|
||||||
explicit Execute(std::vector<std::shared_ptr<TensorTransform>> ops,
|
explicit Execute(const std::vector<std::shared_ptr<TensorTransform>> &ops,
|
||||||
MapTargetDevice device_type = MapTargetDevice::kCpu, uint32_t device_id = 0);
|
MapTargetDevice device_type = MapTargetDevice::kCpu, uint32_t device_id = 0);
|
||||||
|
|
||||||
/// \brief Constructor.
|
/// \brief Constructor.
|
||||||
|
@ -80,7 +80,7 @@ class MS_API Execute {
|
||||||
/// in type of raw pointer.
|
/// in type of raw pointer.
|
||||||
/// \param[in] device_type Target device environment to perform operation, can be kCPU/kGPU/kAscend310 (default=kCPU).
|
/// \param[in] device_type Target device environment to perform operation, can be kCPU/kGPU/kAscend310 (default=kCPU).
|
||||||
/// \param[in] device_id Target device ID to perform operation, only valid when device_type=kAscend310 (default=0).
|
/// \param[in] device_id Target device ID to perform operation, only valid when device_type=kAscend310 (default=0).
|
||||||
explicit Execute(const std::vector<std::reference_wrapper<TensorTransform>> ops,
|
explicit Execute(const std::vector<std::reference_wrapper<TensorTransform>> &ops,
|
||||||
MapTargetDevice device_type = MapTargetDevice::kCpu, uint32_t device_id = 0);
|
MapTargetDevice device_type = MapTargetDevice::kCpu, uint32_t device_id = 0);
|
||||||
|
|
||||||
/// \brief Constructor.
|
/// \brief Constructor.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -22,13 +22,13 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "include/api/dual_abi_helper.h"
|
#include "include/api/dual_abi_helper.h"
|
||||||
#include "include/api/status.h"
|
#include "include/api/status.h"
|
||||||
#include "include/api/types.h"
|
#include "include/api/types.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
|
||||||
// Forward declare
|
// Forward declare
|
||||||
class ExecutionTree;
|
class ExecutionTree;
|
||||||
class DatasetOp;
|
class DatasetOp;
|
||||||
|
@ -57,7 +57,7 @@ class MS_API Iterator {
|
||||||
/// \param[in] ds The last DatasetOp in the dataset pipeline.
|
/// \param[in] ds The last DatasetOp in the dataset pipeline.
|
||||||
/// \param[in] num_epochs Number of epochs passed down to EpochCtrlNode (default=-1, which means infinite epochs).
|
/// \param[in] num_epochs Number of epochs passed down to EpochCtrlNode (default=-1, which means infinite epochs).
|
||||||
/// \return Status error code, returns OK if no error encountered.
|
/// \return Status error code, returns OK if no error encountered.
|
||||||
Status BuildAndLaunchTree(std::shared_ptr<Dataset> ds, int32_t num_epochs);
|
Status BuildAndLaunchTree(const std::shared_ptr<Dataset> &ds, int32_t num_epochs);
|
||||||
|
|
||||||
/// \brief Function to get the next row from the data pipeline.
|
/// \brief Function to get the next row from the data pipeline.
|
||||||
/// \note Type of return data is a unordered_map(with column name).
|
/// \note Type of return data is a unordered_map(with column name).
|
||||||
|
@ -185,7 +185,7 @@ class MS_API PullIterator : public Iterator {
|
||||||
/// \note Consider making this function protected.
|
/// \note Consider making this function protected.
|
||||||
/// \param[in] ds The root node that calls the function.
|
/// \param[in] ds The root node that calls the function.
|
||||||
/// \return Status error code, returns OK if no error encountered.
|
/// \return Status error code, returns OK if no error encountered.
|
||||||
Status BuildAndLaunchTree(std::shared_ptr<Dataset> ds);
|
Status BuildAndLaunchTree(const std::shared_ptr<Dataset> &ds);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<PullBasedIteratorConsumer> pull_consumer_;
|
std::unique_ptr<PullBasedIteratorConsumer> pull_consumer_;
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -24,7 +24,6 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
|
||||||
// Forward declare
|
// Forward declare
|
||||||
class SamplerObj;
|
class SamplerObj;
|
||||||
|
|
||||||
|
@ -72,14 +71,14 @@ class MS_API Sampler : std::enable_shared_from_this<Sampler> {
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/// \brief Constructor
|
/// \brief Constructor
|
||||||
Sampler() {}
|
Sampler() = default;
|
||||||
|
|
||||||
/// \brief Destructor
|
/// \brief Destructor
|
||||||
~Sampler() = default;
|
~Sampler() = default;
|
||||||
|
|
||||||
/// \brief A virtual function to add a child sampler.
|
/// \brief A virtual function to add a child sampler.
|
||||||
/// \param[in] child The child sampler to be added as a children of this sampler.
|
/// \param[in] child The child sampler to be added as a children of this sampler.
|
||||||
virtual void AddChild(std::shared_ptr<Sampler> child) { children_.push_back(child); }
|
virtual void AddChild(const std::shared_ptr<Sampler> &child) { children_.push_back(child); }
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
/// \brief Pure virtual function to convert a Sampler class into an IR Sampler object.
|
/// \brief Pure virtual function to convert a Sampler class into an IR Sampler object.
|
||||||
|
@ -238,7 +237,7 @@ class MS_API SubsetSampler : public Sampler {
|
||||||
/// std::string folder_path = "/path/to/image/folder";
|
/// std::string folder_path = "/path/to/image/folder";
|
||||||
/// std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, std::make_shared<SubsetSampler>({0, 2, 5}));
|
/// std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, std::make_shared<SubsetSampler>({0, 2, 5}));
|
||||||
/// \endcode
|
/// \endcode
|
||||||
explicit SubsetSampler(std::vector<int64_t> indices, int64_t num_samples = 0);
|
explicit SubsetSampler(const std::vector<int64_t> &indices, int64_t num_samples = 0);
|
||||||
|
|
||||||
/// \brief Destructor.
|
/// \brief Destructor.
|
||||||
~SubsetSampler() = default;
|
~SubsetSampler() = default;
|
||||||
|
@ -267,7 +266,7 @@ class MS_API SubsetRandomSampler final : public SubsetSampler {
|
||||||
/// std::string folder_path = "/path/to/image/folder";
|
/// std::string folder_path = "/path/to/image/folder";
|
||||||
/// std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, std::make_shared<SubsetRandomSampler>({2, 7}));
|
/// std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, std::make_shared<SubsetRandomSampler>({2, 7}));
|
||||||
/// \endcode
|
/// \endcode
|
||||||
explicit SubsetRandomSampler(std::vector<int64_t> indices, int64_t num_samples = 0);
|
explicit SubsetRandomSampler(const std::vector<int64_t> &indices, int64_t num_samples = 0);
|
||||||
|
|
||||||
/// \brief Destructor.
|
/// \brief Destructor.
|
||||||
~SubsetRandomSampler() = default;
|
~SubsetRandomSampler() = default;
|
||||||
|
@ -297,7 +296,7 @@ class MS_API WeightedRandomSampler final : public Sampler {
|
||||||
/// std::string folder_path = "/path/to/image/folder";
|
/// std::string folder_path = "/path/to/image/folder";
|
||||||
/// std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, sampler);
|
/// std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, sampler);
|
||||||
/// \endcode
|
/// \endcode
|
||||||
explicit WeightedRandomSampler(std::vector<double> weights, int64_t num_samples = 0, bool replacement = true);
|
explicit WeightedRandomSampler(const std::vector<double> &weights, int64_t num_samples = 0, bool replacement = true);
|
||||||
|
|
||||||
/// \brief Destructor.
|
/// \brief Destructor.
|
||||||
~WeightedRandomSampler() = default;
|
~WeightedRandomSampler() = default;
|
||||||
|
@ -312,7 +311,6 @@ class MS_API WeightedRandomSampler final : public Sampler {
|
||||||
int64_t num_samples_;
|
int64_t num_samples_;
|
||||||
bool replacement_;
|
bool replacement_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_SAMPLERS_H_
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_SAMPLERS_H_
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -30,7 +30,6 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
|
||||||
class SentencePieceVocab;
|
class SentencePieceVocab;
|
||||||
class TensorOperation;
|
class TensorOperation;
|
||||||
class Vectors;
|
class Vectors;
|
||||||
|
@ -63,7 +62,7 @@ class MS_API BasicTokenizer final : public TensorTransform {
|
||||||
/// {"text"}); // input columns
|
/// {"text"}); // input columns
|
||||||
/// \endcode
|
/// \endcode
|
||||||
explicit BasicTokenizer(bool lower_case = false, bool keep_whitespace = false,
|
explicit BasicTokenizer(bool lower_case = false, bool keep_whitespace = false,
|
||||||
const NormalizeForm normalize_form = NormalizeForm::kNone, bool preserve_unused_token = true,
|
NormalizeForm normalize_form = NormalizeForm::kNone, bool preserve_unused_token = true,
|
||||||
bool with_offsets = false);
|
bool with_offsets = false);
|
||||||
|
|
||||||
/// \brief Destructor
|
/// \brief Destructor
|
||||||
|
@ -136,8 +135,7 @@ class MS_API BertTokenizer final : public TensorTransform {
|
||||||
/// \param[in] with_offsets Whether to output offsets of tokens (default=false).
|
/// \param[in] with_offsets Whether to output offsets of tokens (default=false).
|
||||||
BertTokenizer(const std::shared_ptr<Vocab> &vocab, const std::vector<char> &suffix_indicator,
|
BertTokenizer(const std::shared_ptr<Vocab> &vocab, const std::vector<char> &suffix_indicator,
|
||||||
int32_t max_bytes_per_token, const std::vector<char> &unknown_token, bool lower_case,
|
int32_t max_bytes_per_token, const std::vector<char> &unknown_token, bool lower_case,
|
||||||
bool keep_whitespace, const NormalizeForm normalize_form, bool preserve_unused_token,
|
bool keep_whitespace, NormalizeForm normalize_form, bool preserve_unused_token, bool with_offsets);
|
||||||
bool with_offsets);
|
|
||||||
|
|
||||||
/// \brief Destructor
|
/// \brief Destructor
|
||||||
~BertTokenizer() = default;
|
~BertTokenizer() = default;
|
||||||
|
@ -448,7 +446,7 @@ class MS_API RegexReplace final : public TensorTransform {
|
||||||
/// dataset = dataset->Map({regex_op}, // operations
|
/// dataset = dataset->Map({regex_op}, // operations
|
||||||
/// {"text"}); // input columns
|
/// {"text"}); // input columns
|
||||||
/// \endcode
|
/// \endcode
|
||||||
RegexReplace(std::string pattern, std::string replace, bool replace_all = true)
|
RegexReplace(const std::string &pattern, const std::string &replace, bool replace_all = true)
|
||||||
: RegexReplace(StringToChar(pattern), StringToChar(replace), replace_all) {}
|
: RegexReplace(StringToChar(pattern), StringToChar(replace), replace_all) {}
|
||||||
|
|
||||||
/// \brief Constructor.
|
/// \brief Constructor.
|
||||||
|
@ -489,7 +487,8 @@ class MS_API RegexTokenizer final : public TensorTransform {
|
||||||
/// dataset = dataset->Map({regex_op}, // operations
|
/// dataset = dataset->Map({regex_op}, // operations
|
||||||
/// {"text"}); // input columns
|
/// {"text"}); // input columns
|
||||||
/// \endcode
|
/// \endcode
|
||||||
explicit RegexTokenizer(std::string delim_pattern, std::string keep_delim_pattern = "", bool with_offsets = false)
|
explicit RegexTokenizer(const std::string &delim_pattern, const std::string &keep_delim_pattern = "",
|
||||||
|
bool with_offsets = false)
|
||||||
: RegexTokenizer(StringToChar(delim_pattern), StringToChar(keep_delim_pattern), with_offsets) {}
|
: RegexTokenizer(StringToChar(delim_pattern), StringToChar(keep_delim_pattern), with_offsets) {}
|
||||||
|
|
||||||
explicit RegexTokenizer(const std::vector<char> &delim_pattern, const std::vector<char> &keep_delim_pattern,
|
explicit RegexTokenizer(const std::vector<char> &delim_pattern, const std::vector<char> &keep_delim_pattern,
|
||||||
|
@ -581,7 +580,7 @@ class MS_API SlidingWindow final : public TensorTransform {
|
||||||
/// dataset = dataset->Map({slidingwindow_op}, // operations
|
/// dataset = dataset->Map({slidingwindow_op}, // operations
|
||||||
/// {"text"}); // input columns
|
/// {"text"}); // input columns
|
||||||
/// \endcode
|
/// \endcode
|
||||||
explicit SlidingWindow(const int32_t width, const int32_t axis = 0);
|
explicit SlidingWindow(int32_t width, int32_t axis = 0);
|
||||||
|
|
||||||
/// \brief Destructor
|
/// \brief Destructor
|
||||||
~SlidingWindow() = default;
|
~SlidingWindow() = default;
|
||||||
|
@ -637,7 +636,7 @@ class MS_API ToVectors final : public TensorTransform {
|
||||||
/// \param[in] unk_init In case of the token is out-of-vectors (OOV), the result will be initialized with `unk_init`.
|
/// \param[in] unk_init In case of the token is out-of-vectors (OOV), the result will be initialized with `unk_init`.
|
||||||
/// (default={}, means to initialize with zero vectors).
|
/// (default={}, means to initialize with zero vectors).
|
||||||
/// \param[in] lower_case_backup Whether to look up the token in the lower case (default=false).
|
/// \param[in] lower_case_backup Whether to look up the token in the lower case (default=false).
|
||||||
explicit ToVectors(const std::shared_ptr<Vectors> &vectors, std::vector<float> unk_init = {},
|
explicit ToVectors(const std::shared_ptr<Vectors> &vectors, const std::vector<float> &unk_init = {},
|
||||||
bool lower_case_backup = false);
|
bool lower_case_backup = false);
|
||||||
|
|
||||||
/// \brief Destructor
|
/// \brief Destructor
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -29,7 +29,6 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
|
||||||
class TensorOperation;
|
class TensorOperation;
|
||||||
|
|
||||||
// We need the following two groups of forward declaration to friend the class in class TensorTransform.
|
// We need the following two groups of forward declaration to friend the class in class TensorTransform.
|
||||||
|
@ -60,7 +59,7 @@ class MS_API TensorTransform : public std::enable_shared_from_this<TensorTransfo
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/// \brief Constructor
|
/// \brief Constructor
|
||||||
TensorTransform() {}
|
TensorTransform() = default;
|
||||||
|
|
||||||
/// \brief Destructor
|
/// \brief Destructor
|
||||||
~TensorTransform() = default;
|
~TensorTransform() = default;
|
||||||
|
@ -108,10 +107,13 @@ class MS_API SliceOption {
|
||||||
public:
|
public:
|
||||||
/// \param[in] all Slice the whole dimension
|
/// \param[in] all Slice the whole dimension
|
||||||
explicit SliceOption(bool all) : all_(all) {}
|
explicit SliceOption(bool all) : all_(all) {}
|
||||||
|
|
||||||
/// \param[in] indices Slice these indices along the dimension. Negative indices are supported.
|
/// \param[in] indices Slice these indices along the dimension. Negative indices are supported.
|
||||||
explicit SliceOption(std::vector<dsize_t> indices) : indices_(indices) {}
|
explicit SliceOption(const std::vector<dsize_t> &indices) : indices_(indices) {}
|
||||||
|
|
||||||
/// \param[in] slice Slice the generated indices from the slice object along the dimension.
|
/// \param[in] slice Slice the generated indices from the slice object along the dimension.
|
||||||
explicit SliceOption(Slice slice) : slice_(slice) {}
|
explicit SliceOption(Slice slice) : slice_(slice) {}
|
||||||
|
|
||||||
SliceOption(SliceOption const &slice) = default;
|
SliceOption(SliceOption const &slice) = default;
|
||||||
|
|
||||||
~SliceOption() = default;
|
~SliceOption() = default;
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -31,12 +31,10 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
|
||||||
class TensorOperation;
|
class TensorOperation;
|
||||||
|
|
||||||
// Transform operations for performing computer vision.
|
// Transform operations for performing computer vision.
|
||||||
namespace vision {
|
namespace vision {
|
||||||
|
|
||||||
/// \brief AdjustGamma TensorTransform.
|
/// \brief AdjustGamma TensorTransform.
|
||||||
/// \note Apply gamma correction on input image.
|
/// \note Apply gamma correction on input image.
|
||||||
class MS_API AdjustGamma final : public TensorTransform {
|
class MS_API AdjustGamma final : public TensorTransform {
|
||||||
|
@ -49,10 +47,10 @@ class MS_API AdjustGamma final : public TensorTransform {
|
||||||
/// \code
|
/// \code
|
||||||
/// /* Define operations */
|
/// /* Define operations */
|
||||||
/// auto decode_op = vision::Decode();
|
/// auto decode_op = vision::Decode();
|
||||||
/// auto adjustgamma_op = vision::AdjustGamma(10.0);
|
/// auto adjust_gamma_op = vision::AdjustGamma(10.0);
|
||||||
///
|
///
|
||||||
/// /* dataset is an instance of Dataset object */
|
/// /* dataset is an instance of Dataset object */
|
||||||
/// dataset = dataset->Map({decode_op, adjustgamma_op}, // operations
|
/// dataset = dataset->Map({decode_op, adjust_gamma_op}, // operations
|
||||||
/// {"image"}); // input columns
|
/// {"image"}); // input columns
|
||||||
/// \endcode
|
/// \endcode
|
||||||
explicit AdjustGamma(float gamma, float gain = 1);
|
explicit AdjustGamma(float gamma, float gain = 1);
|
||||||
|
@ -93,7 +91,7 @@ class MS_API AutoAugment final : public TensorTransform {
|
||||||
/// dataset = dataset->Map({decode_op, auto_augment_op}, // operations
|
/// dataset = dataset->Map({decode_op, auto_augment_op}, // operations
|
||||||
/// {"image"}); // input columns
|
/// {"image"}); // input columns
|
||||||
/// \endcode
|
/// \endcode
|
||||||
AutoAugment(AutoAugmentPolicy policy = AutoAugmentPolicy::kImageNet,
|
explicit AutoAugment(AutoAugmentPolicy policy = AutoAugmentPolicy::kImageNet,
|
||||||
InterpolationMode interpolation = InterpolationMode::kNearestNeighbour,
|
InterpolationMode interpolation = InterpolationMode::kNearestNeighbour,
|
||||||
const std::vector<uint8_t> &fill_value = {0, 0, 0});
|
const std::vector<uint8_t> &fill_value = {0, 0, 0});
|
||||||
|
|
||||||
|
@ -126,7 +124,7 @@ class MS_API AutoContrast final : public TensorTransform {
|
||||||
/// dataset = dataset->Map({decode_op, autocontrast_op}, // operations
|
/// dataset = dataset->Map({decode_op, autocontrast_op}, // operations
|
||||||
/// {"image"}); // input columns
|
/// {"image"}); // input columns
|
||||||
/// \endcode
|
/// \endcode
|
||||||
explicit AutoContrast(float cutoff = 0.0, std::vector<uint32_t> ignore = {});
|
explicit AutoContrast(float cutoff = 0.0, const std::vector<uint32_t> &ignore = {});
|
||||||
|
|
||||||
/// \brief Destructor.
|
/// \brief Destructor.
|
||||||
~AutoContrast() = default;
|
~AutoContrast() = default;
|
||||||
|
@ -188,7 +186,7 @@ class MS_API BoundingBoxAugment final : public TensorTransform {
|
||||||
/// dataset = dataset->Map({bbox_aug_op}, // operations
|
/// dataset = dataset->Map({bbox_aug_op}, // operations
|
||||||
/// {"image", "bbox"}); // input columns
|
/// {"image", "bbox"}); // input columns
|
||||||
/// \endcode
|
/// \endcode
|
||||||
explicit BoundingBoxAugment(const std::reference_wrapper<TensorTransform> transform, float ratio = 0.3);
|
explicit BoundingBoxAugment(const std::reference_wrapper<TensorTransform> &transform, float ratio = 0.3);
|
||||||
|
|
||||||
/// \brief Destructor.
|
/// \brief Destructor.
|
||||||
~BoundingBoxAugment() = default;
|
~BoundingBoxAugment() = default;
|
||||||
|
@ -473,7 +471,7 @@ class MS_API Pad final : public TensorTransform {
|
||||||
/// dataset = dataset->Map({decode_op, pad_op}, // operations
|
/// dataset = dataset->Map({decode_op, pad_op}, // operations
|
||||||
/// {"image"}); // input columns
|
/// {"image"}); // input columns
|
||||||
/// \endcode
|
/// \endcode
|
||||||
explicit Pad(std::vector<int32_t> padding, std::vector<uint8_t> fill_value = {0},
|
explicit Pad(const std::vector<int32_t> &padding, const std::vector<uint8_t> &fill_value = {0},
|
||||||
BorderType padding_mode = BorderType::kConstant);
|
BorderType padding_mode = BorderType::kConstant);
|
||||||
|
|
||||||
/// \brief Destructor.
|
/// \brief Destructor.
|
||||||
|
@ -509,7 +507,7 @@ class MS_API RandomAutoContrast final : public TensorTransform {
|
||||||
/// dataset = dataset->Map({decode_op, random_auto_contrast_op}, // operations
|
/// dataset = dataset->Map({decode_op, random_auto_contrast_op}, // operations
|
||||||
/// {"image"}); // input columns
|
/// {"image"}); // input columns
|
||||||
/// \endcode
|
/// \endcode
|
||||||
explicit RandomAutoContrast(float cutoff = 0.0, std::vector<uint32_t> ignore = {}, float prob = 0.5);
|
explicit RandomAutoContrast(float cutoff = 0.0, const std::vector<uint32_t> &ignore = {}, float prob = 0.5);
|
||||||
|
|
||||||
/// \brief Destructor.
|
/// \brief Destructor.
|
||||||
~RandomAutoContrast() = default;
|
~RandomAutoContrast() = default;
|
||||||
|
@ -612,8 +610,10 @@ class MS_API RandomColorAdjust final : public TensorTransform {
|
||||||
/// dataset = dataset->Map({decode_op, random_color_adjust_op}, // operations
|
/// dataset = dataset->Map({decode_op, random_color_adjust_op}, // operations
|
||||||
/// {"image"}); // input columns
|
/// {"image"}); // input columns
|
||||||
/// \endcode
|
/// \endcode
|
||||||
explicit RandomColorAdjust(std::vector<float> brightness = {1.0, 1.0}, std::vector<float> contrast = {1.0, 1.0},
|
explicit RandomColorAdjust(const std::vector<float> &brightness = {1.0, 1.0},
|
||||||
std::vector<float> saturation = {1.0, 1.0}, std::vector<float> hue = {0.0, 0.0});
|
const std::vector<float> &contrast = {1.0, 1.0},
|
||||||
|
const std::vector<float> &saturation = {1.0, 1.0},
|
||||||
|
const std::vector<float> &hue = {0.0, 0.0});
|
||||||
|
|
||||||
/// \brief Destructor.
|
/// \brief Destructor.
|
||||||
~RandomColorAdjust() = default;
|
~RandomColorAdjust() = default;
|
||||||
|
@ -663,8 +663,8 @@ class MS_API RandomCrop final : public TensorTransform {
|
||||||
/// dataset = dataset->Map({decode_op, random_crop_op}, // operations
|
/// dataset = dataset->Map({decode_op, random_crop_op}, // operations
|
||||||
/// {"image"}); // input columns
|
/// {"image"}); // input columns
|
||||||
/// \endcode
|
/// \endcode
|
||||||
explicit RandomCrop(std::vector<int32_t> size, std::vector<int32_t> padding = {0, 0, 0, 0},
|
explicit RandomCrop(const std::vector<int32_t> &size, const std::vector<int32_t> &padding = {0, 0, 0, 0},
|
||||||
bool pad_if_needed = false, std::vector<uint8_t> fill_value = {0, 0, 0},
|
bool pad_if_needed = false, const std::vector<uint8_t> &fill_value = {0, 0, 0},
|
||||||
BorderType padding_mode = BorderType::kConstant);
|
BorderType padding_mode = BorderType::kConstant);
|
||||||
|
|
||||||
/// \brief Destructor.
|
/// \brief Destructor.
|
||||||
|
@ -708,8 +708,8 @@ class MS_API RandomCropDecodeResize final : public TensorTransform {
|
||||||
/// dataset = dataset->Map({random_op}, // operations
|
/// dataset = dataset->Map({random_op}, // operations
|
||||||
/// {"image"}); // input columns
|
/// {"image"}); // input columns
|
||||||
/// \endcode
|
/// \endcode
|
||||||
explicit RandomCropDecodeResize(std::vector<int32_t> size, std::vector<float> scale = {0.08, 1.0},
|
explicit RandomCropDecodeResize(const std::vector<int32_t> &size, const std::vector<float> &scale = {0.08, 1.0},
|
||||||
std::vector<float> ratio = {3. / 4, 4. / 3},
|
const std::vector<float> &ratio = {3. / 4, 4. / 3},
|
||||||
InterpolationMode interpolation = InterpolationMode::kLinear,
|
InterpolationMode interpolation = InterpolationMode::kLinear,
|
||||||
int32_t max_attempts = 10);
|
int32_t max_attempts = 10);
|
||||||
|
|
||||||
|
@ -760,8 +760,8 @@ class MS_API RandomCropWithBBox final : public TensorTransform {
|
||||||
/// dataset = dataset->Map({random_op}, // operations
|
/// dataset = dataset->Map({random_op}, // operations
|
||||||
/// {"image", "bbox"}); // input columns
|
/// {"image", "bbox"}); // input columns
|
||||||
/// \endcode
|
/// \endcode
|
||||||
explicit RandomCropWithBBox(std::vector<int32_t> size, std::vector<int32_t> padding = {0, 0, 0, 0},
|
explicit RandomCropWithBBox(const std::vector<int32_t> &size, const std::vector<int32_t> &padding = {0, 0, 0, 0},
|
||||||
bool pad_if_needed = false, std::vector<uint8_t> fill_value = {0, 0, 0},
|
bool pad_if_needed = false, const std::vector<uint8_t> &fill_value = {0, 0, 0},
|
||||||
BorderType padding_mode = BorderType::kConstant);
|
BorderType padding_mode = BorderType::kConstant);
|
||||||
|
|
||||||
/// \brief Destructor.
|
/// \brief Destructor.
|
||||||
|
@ -976,7 +976,7 @@ class MS_API RandomResize final : public TensorTransform {
|
||||||
/// dataset = dataset->Map({decode_op, random_op}, // operations
|
/// dataset = dataset->Map({decode_op, random_op}, // operations
|
||||||
/// {"image"}); // input columns
|
/// {"image"}); // input columns
|
||||||
/// \endcode
|
/// \endcode
|
||||||
explicit RandomResize(std::vector<int32_t> size);
|
explicit RandomResize(const std::vector<int32_t> &size);
|
||||||
|
|
||||||
/// \brief Destructor.
|
/// \brief Destructor.
|
||||||
~RandomResize() = default;
|
~RandomResize() = default;
|
||||||
|
@ -1008,7 +1008,7 @@ class MS_API RandomResizeWithBBox final : public TensorTransform {
|
||||||
/// dataset = dataset->Map({random_op}, // operations
|
/// dataset = dataset->Map({random_op}, // operations
|
||||||
/// {"image", "bbox"}); // input columns
|
/// {"image", "bbox"}); // input columns
|
||||||
/// \endcode
|
/// \endcode
|
||||||
explicit RandomResizeWithBBox(std::vector<int32_t> size);
|
explicit RandomResizeWithBBox(const std::vector<int32_t> &size);
|
||||||
|
|
||||||
/// \brief Destructor.
|
/// \brief Destructor.
|
||||||
~RandomResizeWithBBox() = default;
|
~RandomResizeWithBBox() = default;
|
||||||
|
@ -1053,8 +1053,8 @@ class MS_API RandomResizedCrop final : public TensorTransform {
|
||||||
/// dataset = dataset->Map({decode_op, random_op}, // operations
|
/// dataset = dataset->Map({decode_op, random_op}, // operations
|
||||||
/// {"image"}); // input columns
|
/// {"image"}); // input columns
|
||||||
/// \endcode
|
/// \endcode
|
||||||
explicit RandomResizedCrop(std::vector<int32_t> size, std::vector<float> scale = {0.08, 1.0},
|
explicit RandomResizedCrop(const std::vector<int32_t> &size, const std::vector<float> &scale = {0.08, 1.0},
|
||||||
std::vector<float> ratio = {3. / 4., 4. / 3.},
|
const std::vector<float> &ratio = {3. / 4., 4. / 3.},
|
||||||
InterpolationMode interpolation = InterpolationMode::kLinear, int32_t max_attempts = 10);
|
InterpolationMode interpolation = InterpolationMode::kLinear, int32_t max_attempts = 10);
|
||||||
|
|
||||||
/// \brief Destructor.
|
/// \brief Destructor.
|
||||||
|
@ -1100,9 +1100,10 @@ class MS_API RandomResizedCropWithBBox final : public TensorTransform {
|
||||||
/// dataset = dataset->Map({random_op}, // operations
|
/// dataset = dataset->Map({random_op}, // operations
|
||||||
/// {"image", "bbox"}); // input columns
|
/// {"image", "bbox"}); // input columns
|
||||||
/// \endcode
|
/// \endcode
|
||||||
RandomResizedCropWithBBox(std::vector<int32_t> size, std::vector<float> scale = {0.08, 1.0},
|
explicit RandomResizedCropWithBBox(const std::vector<int32_t> &size, const std::vector<float> &scale = {0.08, 1.0},
|
||||||
std::vector<float> ratio = {3. / 4., 4. / 3.},
|
const std::vector<float> &ratio = {3. / 4., 4. / 3.},
|
||||||
InterpolationMode interpolation = InterpolationMode::kLinear, int32_t max_attempts = 10);
|
InterpolationMode interpolation = InterpolationMode::kLinear,
|
||||||
|
int32_t max_attempts = 10);
|
||||||
|
|
||||||
/// \brief Destructor.
|
/// \brief Destructor.
|
||||||
~RandomResizedCropWithBBox() = default;
|
~RandomResizedCropWithBBox() = default;
|
||||||
|
@ -1144,8 +1145,9 @@ class MS_API RandomRotation final : public TensorTransform {
|
||||||
/// dataset = dataset->Map({decode_op, random_op}, // operations
|
/// dataset = dataset->Map({decode_op, random_op}, // operations
|
||||||
/// {"image"}); // input columns
|
/// {"image"}); // input columns
|
||||||
/// \endcode
|
/// \endcode
|
||||||
RandomRotation(std::vector<float> degrees, InterpolationMode resample = InterpolationMode::kNearestNeighbour,
|
explicit RandomRotation(const std::vector<float> °rees,
|
||||||
bool expand = false, std::vector<float> center = {}, std::vector<uint8_t> fill_value = {0, 0, 0});
|
InterpolationMode resample = InterpolationMode::kNearestNeighbour, bool expand = false,
|
||||||
|
const std::vector<float> ¢er = {}, const std::vector<uint8_t> &fill_value = {0, 0, 0});
|
||||||
|
|
||||||
/// \brief Destructor.
|
/// \brief Destructor.
|
||||||
~RandomRotation() = default;
|
~RandomRotation() = default;
|
||||||
|
@ -1255,7 +1257,7 @@ class MS_API RandomSharpness final : public TensorTransform {
|
||||||
/// dataset = dataset->Map({decode_op, random_op}, // operations
|
/// dataset = dataset->Map({decode_op, random_op}, // operations
|
||||||
/// {"image"}); // input columns
|
/// {"image"}); // input columns
|
||||||
/// \endcode
|
/// \endcode
|
||||||
explicit RandomSharpness(std::vector<float> degrees = {0.1, 1.9});
|
explicit RandomSharpness(const std::vector<float> °rees = {0.1, 1.9});
|
||||||
|
|
||||||
/// \brief Destructor.
|
/// \brief Destructor.
|
||||||
~RandomSharpness() = default;
|
~RandomSharpness() = default;
|
||||||
|
@ -1287,7 +1289,7 @@ class MS_API RandomSolarize final : public TensorTransform {
|
||||||
/// dataset = dataset->Map({decode_op, random_op}, // operations
|
/// dataset = dataset->Map({decode_op, random_op}, // operations
|
||||||
/// {"image"}); // input columns
|
/// {"image"}); // input columns
|
||||||
/// \endcode
|
/// \endcode
|
||||||
explicit RandomSolarize(std::vector<uint8_t> threshold = {0, 255});
|
explicit RandomSolarize(const std::vector<uint8_t> &threshold = {0, 255});
|
||||||
|
|
||||||
/// \brief Destructor.
|
/// \brief Destructor.
|
||||||
~RandomSolarize() = default;
|
~RandomSolarize() = default;
|
||||||
|
@ -1414,7 +1416,8 @@ class MS_API ResizeWithBBox final : public TensorTransform {
|
||||||
/// dataset = dataset->Map({random_op}, // operations
|
/// dataset = dataset->Map({random_op}, // operations
|
||||||
/// {"image", "bbox"}); // input columns
|
/// {"image", "bbox"}); // input columns
|
||||||
/// \endcode
|
/// \endcode
|
||||||
explicit ResizeWithBBox(std::vector<int32_t> size, InterpolationMode interpolation = InterpolationMode::kLinear);
|
explicit ResizeWithBBox(const std::vector<int32_t> &size,
|
||||||
|
InterpolationMode interpolation = InterpolationMode::kLinear);
|
||||||
|
|
||||||
/// \brief Destructor.
|
/// \brief Destructor.
|
||||||
~ResizeWithBBox() = default;
|
~ResizeWithBBox() = default;
|
||||||
|
@ -1500,7 +1503,7 @@ class MS_API SlicePatches final : public TensorTransform {
|
||||||
/// dataset = dataset->Map({decode_op, slice_patch_op}, // operations
|
/// dataset = dataset->Map({decode_op, slice_patch_op}, // operations
|
||||||
/// {"image"}); // input columns
|
/// {"image"}); // input columns
|
||||||
/// \endcode
|
/// \endcode
|
||||||
SlicePatches(int32_t num_height = 1, int32_t num_width = 1, SliceMode slice_mode = SliceMode::kPad,
|
explicit SlicePatches(int32_t num_height = 1, int32_t num_width = 1, SliceMode slice_mode = SliceMode::kPad,
|
||||||
uint8_t fill_value = 0);
|
uint8_t fill_value = 0);
|
||||||
|
|
||||||
/// \brief Destructor.
|
/// \brief Destructor.
|
||||||
|
@ -1542,8 +1545,10 @@ class MS_API SoftDvppDecodeRandomCropResizeJpeg final : public TensorTransform {
|
||||||
/// dataset = dataset->Map({dvpp_op}, // operations
|
/// dataset = dataset->Map({dvpp_op}, // operations
|
||||||
/// {"image"}); // input columns
|
/// {"image"}); // input columns
|
||||||
/// \endcode
|
/// \endcode
|
||||||
SoftDvppDecodeRandomCropResizeJpeg(std::vector<int32_t> size, std::vector<float> scale = {0.08, 1.0},
|
explicit SoftDvppDecodeRandomCropResizeJpeg(const std::vector<int32_t> &size,
|
||||||
std::vector<float> ratio = {3. / 4., 4. / 3.}, int32_t max_attempts = 10);
|
const std::vector<float> &scale = {0.08, 1.0},
|
||||||
|
const std::vector<float> &ratio = {3. / 4., 4. / 3.},
|
||||||
|
int32_t max_attempts = 10);
|
||||||
|
|
||||||
/// \brief Destructor.
|
/// \brief Destructor.
|
||||||
~SoftDvppDecodeRandomCropResizeJpeg() = default;
|
~SoftDvppDecodeRandomCropResizeJpeg() = default;
|
||||||
|
@ -1581,7 +1586,7 @@ class MS_API SoftDvppDecodeResizeJpeg final : public TensorTransform {
|
||||||
/// dataset = dataset->Map({dvpp_op}, // operations
|
/// dataset = dataset->Map({dvpp_op}, // operations
|
||||||
/// {"image"}); // input columns
|
/// {"image"}); // input columns
|
||||||
/// \endcode
|
/// \endcode
|
||||||
explicit SoftDvppDecodeResizeJpeg(std::vector<int32_t> size);
|
explicit SoftDvppDecodeResizeJpeg(const std::vector<int32_t> &size);
|
||||||
|
|
||||||
/// \brief Destructor.
|
/// \brief Destructor.
|
||||||
~SoftDvppDecodeResizeJpeg() = default;
|
~SoftDvppDecodeResizeJpeg() = default;
|
||||||
|
@ -1712,7 +1717,6 @@ class MS_API VerticalFlip final : public TensorTransform {
|
||||||
/// \return Shared pointer to TensorOperation object.
|
/// \return Shared pointer to TensorOperation object.
|
||||||
std::shared_ptr<TensorOperation> Parse() override;
|
std::shared_ptr<TensorOperation> Parse() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace vision
|
} // namespace vision
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -22,16 +22,15 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "include/api/status.h"
|
#include "include/api/status.h"
|
||||||
#include "include/dataset/constants.h"
|
#include "include/dataset/constants.h"
|
||||||
#include "include/dataset/transforms.h"
|
#include "include/dataset/transforms.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
|
||||||
// Transform operations for performing computer vision.
|
// Transform operations for performing computer vision.
|
||||||
namespace vision {
|
namespace vision {
|
||||||
|
|
||||||
/* ##################################### API class ###########################################*/
|
/* ##################################### API class ###########################################*/
|
||||||
|
|
||||||
/// \brief Decode and resize JPEG image using the hardware algorithm of
|
/// \brief Decode and resize JPEG image using the hardware algorithm of
|
||||||
|
@ -49,7 +48,7 @@ class MS_API DvppDecodeResizeJpeg final : public TensorTransform {
|
||||||
/// dataset = dataset->Map({dvpp_op}, // operations
|
/// dataset = dataset->Map({dvpp_op}, // operations
|
||||||
/// {"image"}); // input columns
|
/// {"image"}); // input columns
|
||||||
/// \endcode
|
/// \endcode
|
||||||
explicit DvppDecodeResizeJpeg(std::vector<uint32_t> resize);
|
explicit DvppDecodeResizeJpeg(const std::vector<uint32_t> &resize);
|
||||||
|
|
||||||
/// \brief Destructor.
|
/// \brief Destructor.
|
||||||
~DvppDecodeResizeJpeg() = default;
|
~DvppDecodeResizeJpeg() = default;
|
||||||
|
@ -82,7 +81,7 @@ class MS_API DvppDecodeResizeCropJpeg final : public TensorTransform {
|
||||||
/// dataset = dataset->Map({dvpp_op}, // operations
|
/// dataset = dataset->Map({dvpp_op}, // operations
|
||||||
/// {"image"}); // input columns
|
/// {"image"}); // input columns
|
||||||
/// \endcode
|
/// \endcode
|
||||||
DvppDecodeResizeCropJpeg(std::vector<uint32_t> crop, std::vector<uint32_t> resize);
|
DvppDecodeResizeCropJpeg(const std::vector<uint32_t> &crop, const std::vector<uint32_t> &resize);
|
||||||
|
|
||||||
/// \brief Destructor.
|
/// \brief Destructor.
|
||||||
~DvppDecodeResizeCropJpeg() = default;
|
~DvppDecodeResizeCropJpeg() = default;
|
||||||
|
@ -125,7 +124,6 @@ class MS_API DvppDecodePng final : public TensorTransform {
|
||||||
|
|
||||||
std::shared_ptr<TensorOperation> Parse(const MapTargetDevice &env) override;
|
std::shared_ptr<TensorOperation> Parse(const MapTargetDevice &env) override;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace vision
|
} // namespace vision
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -22,16 +22,15 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "include/api/status.h"
|
#include "include/api/status.h"
|
||||||
#include "include/dataset/constants.h"
|
#include "include/dataset/constants.h"
|
||||||
#include "include/dataset/transforms.h"
|
#include "include/dataset/transforms.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
|
||||||
// Transform operations for performing computer vision.
|
// Transform operations for performing computer vision.
|
||||||
namespace vision {
|
namespace vision {
|
||||||
|
|
||||||
// Forward Declarations
|
// Forward Declarations
|
||||||
class RotateOperation;
|
class RotateOperation;
|
||||||
|
|
||||||
|
@ -96,7 +95,7 @@ class MS_API CenterCrop final : public TensorTransform {
|
||||||
/// dataset = dataset->Map({decode_op, crop_op}, // operations
|
/// dataset = dataset->Map({decode_op, crop_op}, // operations
|
||||||
/// {"image"}); // input columns
|
/// {"image"}); // input columns
|
||||||
/// \endcode
|
/// \endcode
|
||||||
explicit CenterCrop(std::vector<int32_t> size);
|
explicit CenterCrop(const std::vector<int32_t> &size);
|
||||||
|
|
||||||
/// \brief Destructor.
|
/// \brief Destructor.
|
||||||
~CenterCrop() = default;
|
~CenterCrop() = default;
|
||||||
|
@ -131,7 +130,7 @@ class MS_API Crop final : public TensorTransform {
|
||||||
/// dataset = dataset->Map({decode_op, crop_op}, // operations
|
/// dataset = dataset->Map({decode_op, crop_op}, // operations
|
||||||
/// {"image"}); // input columns
|
/// {"image"}); // input columns
|
||||||
/// \endcode
|
/// \endcode
|
||||||
Crop(std::vector<int32_t> coordinates, std::vector<int32_t> size);
|
Crop(const std::vector<int32_t> &coordinates, const std::vector<int32_t> &size);
|
||||||
|
|
||||||
/// \brief Destructor.
|
/// \brief Destructor.
|
||||||
~Crop() = default;
|
~Crop() = default;
|
||||||
|
@ -194,7 +193,7 @@ class MS_API GaussianBlur final : public TensorTransform {
|
||||||
/// dataset = dataset->Map({decode_op, gaussian_op}, // operations
|
/// dataset = dataset->Map({decode_op, gaussian_op}, // operations
|
||||||
/// {"image"}); // input columns
|
/// {"image"}); // input columns
|
||||||
/// \endcode
|
/// \endcode
|
||||||
GaussianBlur(const std::vector<int32_t> &kernel_size, const std::vector<float> &sigma = {0., 0.});
|
explicit GaussianBlur(const std::vector<int32_t> &kernel_size, const std::vector<float> &sigma = {0., 0.});
|
||||||
|
|
||||||
/// \brief Destructor.
|
/// \brief Destructor.
|
||||||
~GaussianBlur() = default;
|
~GaussianBlur() = default;
|
||||||
|
@ -227,7 +226,7 @@ class MS_API Normalize final : public TensorTransform {
|
||||||
/// dataset = dataset->Map({decode_op, normalize_op}, // operations
|
/// dataset = dataset->Map({decode_op, normalize_op}, // operations
|
||||||
/// {"image"}); // input columns
|
/// {"image"}); // input columns
|
||||||
/// \endcode
|
/// \endcode
|
||||||
Normalize(std::vector<float> mean, std::vector<float> std);
|
Normalize(const std::vector<float> &mean, const std::vector<float> &std);
|
||||||
|
|
||||||
/// \brief Destructor.
|
/// \brief Destructor.
|
||||||
~Normalize() = default;
|
~Normalize() = default;
|
||||||
|
@ -318,7 +317,7 @@ class MS_API Resize final : public TensorTransform {
|
||||||
/// dataset = dataset->Map({decode_op, resize_op}, // operations
|
/// dataset = dataset->Map({decode_op, resize_op}, // operations
|
||||||
/// {"image"}); // input columns
|
/// {"image"}); // input columns
|
||||||
/// \endcode
|
/// \endcode
|
||||||
explicit Resize(std::vector<int32_t> size, InterpolationMode interpolation = InterpolationMode::kLinear);
|
explicit Resize(const std::vector<int32_t> &size, InterpolationMode interpolation = InterpolationMode::kLinear);
|
||||||
|
|
||||||
/// \brief Destructor.
|
/// \brief Destructor.
|
||||||
~Resize() = default;
|
~Resize() = default;
|
||||||
|
@ -477,8 +476,8 @@ class MS_API Rotate final : public TensorTransform {
|
||||||
/// dataset = dataset->Map({decode_op, rotate_op}, // operations
|
/// dataset = dataset->Map({decode_op, rotate_op}, // operations
|
||||||
/// {"image"}); // input columns
|
/// {"image"}); // input columns
|
||||||
/// \endcode
|
/// \endcode
|
||||||
Rotate(float degrees, InterpolationMode resample = InterpolationMode::kNearestNeighbour, bool expand = false,
|
explicit Rotate(float degrees, InterpolationMode resample = InterpolationMode::kNearestNeighbour, bool expand = false,
|
||||||
std::vector<float> center = {}, std::vector<uint8_t> fill_value = {0, 0, 0});
|
const std::vector<float> ¢er = {}, const std::vector<uint8_t> &fill_value = {0, 0, 0});
|
||||||
|
|
||||||
/// \brief Destructor.
|
/// \brief Destructor.
|
||||||
~Rotate() = default;
|
~Rotate() = default;
|
||||||
|
@ -493,7 +492,6 @@ class MS_API Rotate final : public TensorTransform {
|
||||||
struct Data;
|
struct Data;
|
||||||
std::shared_ptr<Data> data_;
|
std::shared_ptr<Data> data_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace vision
|
} // namespace vision
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_DATASETS_H_
|
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_LITEAPI_INCLUDE_DATASETS_H_
|
||||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_DATASETS_H_
|
#define MINDSPORE_CCSRC_MINDDATA_DATASET_LITEAPI_INCLUDE_DATASETS_H_
|
||||||
|
|
||||||
#include <sys/stat.h>
|
#include <sys/stat.h>
|
||||||
#include <unistd.h>
|
#include <unistd.h>
|
||||||
|
@ -38,7 +38,6 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
|
||||||
class Tensor;
|
class Tensor;
|
||||||
class TensorShape;
|
class TensorShape;
|
||||||
class TreeAdapter;
|
class TreeAdapter;
|
||||||
|
@ -132,7 +131,7 @@ class MS_API Dataset : public std::enable_shared_from_this<Dataset> {
|
||||||
/// std::unordered_map<std::string, mindspore::MSTensor> row;
|
/// std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||||
/// iter->GetNextRow(&row);
|
/// iter->GetNextRow(&row);
|
||||||
/// \endcode
|
/// \endcode
|
||||||
std::shared_ptr<PullIterator> CreatePullBasedIterator(std::vector<std::vector<char>> columns = {});
|
std::shared_ptr<PullIterator> CreatePullBasedIterator(const std::vector<std::vector<char>> &columns = {});
|
||||||
|
|
||||||
/// \brief Function to create an Iterator over the Dataset pipeline
|
/// \brief Function to create an Iterator over the Dataset pipeline
|
||||||
/// \param[in] columns List of columns to be used to specify the order of columns
|
/// \param[in] columns List of columns to be used to specify the order of columns
|
||||||
|
@ -146,7 +145,7 @@ class MS_API Dataset : public std::enable_shared_from_this<Dataset> {
|
||||||
/// std::unordered_map<std::string, mindspore::MSTensor> row;
|
/// std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||||
/// iter->GetNextRow(&row);
|
/// iter->GetNextRow(&row);
|
||||||
/// \endcode
|
/// \endcode
|
||||||
std::shared_ptr<Iterator> CreateIterator(std::vector<std::string> columns = {}, int32_t num_epochs = -1) {
|
std::shared_ptr<Iterator> CreateIterator(const std::vector<std::string> &columns = {}, int32_t num_epochs = -1) {
|
||||||
return CreateIteratorCharIF(VectorStringToChar(columns), num_epochs);
|
return CreateIteratorCharIF(VectorStringToChar(columns), num_epochs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -162,7 +161,7 @@ class MS_API Dataset : public std::enable_shared_from_this<Dataset> {
|
||||||
/// \param[in] create_data_info_queue Whether to create queue which stores types and shapes
|
/// \param[in] create_data_info_queue Whether to create queue which stores types and shapes
|
||||||
/// of data or not(default=false).
|
/// of data or not(default=false).
|
||||||
/// \return Returns true if no error encountered else false.
|
/// \return Returns true if no error encountered else false.
|
||||||
bool DeviceQueue(std::string queue_name = "", std::string device_type = "", int32_t device_id = 0,
|
bool DeviceQueue(const std::string &queue_name = "", const std::string &device_type = "", int32_t device_id = 0,
|
||||||
int32_t num_epochs = -1, bool send_epoch_end = true, int32_t total_batches = 0,
|
int32_t num_epochs = -1, bool send_epoch_end = true, int32_t total_batches = 0,
|
||||||
bool create_data_info_queue = false) {
|
bool create_data_info_queue = false) {
|
||||||
return DeviceQueueCharIF(StringToChar(queue_name), StringToChar(device_type), device_id, num_epochs, send_epoch_end,
|
return DeviceQueueCharIF(StringToChar(queue_name), StringToChar(device_type), device_id, num_epochs, send_epoch_end,
|
||||||
|
@ -189,7 +188,7 @@ class MS_API Dataset : public std::enable_shared_from_this<Dataset> {
|
||||||
/// std::string save_file = "Cifar10Data.mindrecord";
|
/// std::string save_file = "Cifar10Data.mindrecord";
|
||||||
/// bool rc = ds->Save(save_file);
|
/// bool rc = ds->Save(save_file);
|
||||||
/// \endcode
|
/// \endcode
|
||||||
bool Save(std::string dataset_path, int32_t num_files = 1, std::string dataset_type = "mindrecord") {
|
bool Save(const std::string &dataset_path, int32_t num_files = 1, const std::string &dataset_type = "mindrecord") {
|
||||||
return SaveCharIF(StringToChar(dataset_path), num_files, StringToChar(dataset_type));
|
return SaveCharIF(StringToChar(dataset_path), num_files, StringToChar(dataset_type));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -259,12 +258,12 @@ class MS_API Dataset : public std::enable_shared_from_this<Dataset> {
|
||||||
/// // columns will remain the same.
|
/// // columns will remain the same.
|
||||||
/// dataset = dataset->Map({decode_op, random_jitter_op}, {"image"})
|
/// dataset = dataset->Map({decode_op, random_jitter_op}, {"image"})
|
||||||
/// \endcode
|
/// \endcode
|
||||||
std::shared_ptr<MapDataset> Map(std::vector<TensorTransform *> operations,
|
std::shared_ptr<MapDataset> Map(const std::vector<TensorTransform *> &operations,
|
||||||
const std::vector<std::string> &input_columns = {},
|
const std::vector<std::string> &input_columns = {},
|
||||||
const std::vector<std::string> &output_columns = {},
|
const std::vector<std::string> &output_columns = {},
|
||||||
const std::vector<std::string> &project_columns = {},
|
const std::vector<std::string> &project_columns = {},
|
||||||
const std::shared_ptr<DatasetCache> &cache = nullptr,
|
const std::shared_ptr<DatasetCache> &cache = nullptr,
|
||||||
std::vector<std::shared_ptr<DSCallback>> callbacks = {}) {
|
const std::vector<std::shared_ptr<DSCallback>> &callbacks = {}) {
|
||||||
std::vector<std::shared_ptr<TensorOperation>> transform_ops;
|
std::vector<std::shared_ptr<TensorOperation>> transform_ops;
|
||||||
(void)std::transform(
|
(void)std::transform(
|
||||||
operations.begin(), operations.end(), std::back_inserter(transform_ops),
|
operations.begin(), operations.end(), std::back_inserter(transform_ops),
|
||||||
|
@ -290,15 +289,15 @@ class MS_API Dataset : public std::enable_shared_from_this<Dataset> {
|
||||||
/// \param[in] project_columns A list of column names to project
|
/// \param[in] project_columns A list of column names to project
|
||||||
/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
|
/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
|
||||||
/// \return Shared pointer to the current MapDataset
|
/// \return Shared pointer to the current MapDataset
|
||||||
std::shared_ptr<MapDataset> Map(std::vector<std::shared_ptr<TensorTransform>> operations,
|
std::shared_ptr<MapDataset> Map(const std::vector<std::shared_ptr<TensorTransform>> &operations,
|
||||||
const std::vector<std::string> &input_columns = {},
|
const std::vector<std::string> &input_columns = {},
|
||||||
const std::vector<std::string> &output_columns = {},
|
const std::vector<std::string> &output_columns = {},
|
||||||
const std::vector<std::string> &project_columns = {},
|
const std::vector<std::string> &project_columns = {},
|
||||||
const std::shared_ptr<DatasetCache> &cache = nullptr,
|
const std::shared_ptr<DatasetCache> &cache = nullptr,
|
||||||
std::vector<std::shared_ptr<DSCallback>> callbacks = {}) {
|
const std::vector<std::shared_ptr<DSCallback>> &callbacks = {}) {
|
||||||
std::vector<std::shared_ptr<TensorOperation>> transform_ops;
|
std::vector<std::shared_ptr<TensorOperation>> transform_ops;
|
||||||
(void)std::transform(operations.begin(), operations.end(), std::back_inserter(transform_ops),
|
(void)std::transform(operations.begin(), operations.end(), std::back_inserter(transform_ops),
|
||||||
[](std::shared_ptr<TensorTransform> op) -> std::shared_ptr<TensorOperation> {
|
[](const std::shared_ptr<TensorTransform> &op) -> std::shared_ptr<TensorOperation> {
|
||||||
return op != nullptr ? op->Parse() : nullptr;
|
return op != nullptr ? op->Parse() : nullptr;
|
||||||
});
|
});
|
||||||
return std::make_shared<MapDataset>(shared_from_this(), transform_ops, VectorStringToChar(input_columns),
|
return std::make_shared<MapDataset>(shared_from_this(), transform_ops, VectorStringToChar(input_columns),
|
||||||
|
@ -322,12 +321,12 @@ class MS_API Dataset : public std::enable_shared_from_this<Dataset> {
|
||||||
/// \param[in] project_columns A list of column names to project
|
/// \param[in] project_columns A list of column names to project
|
||||||
/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
|
/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
|
||||||
/// \return Shared pointer to the current MapDataset
|
/// \return Shared pointer to the current MapDataset
|
||||||
std::shared_ptr<MapDataset> Map(const std::vector<std::reference_wrapper<TensorTransform>> operations,
|
std::shared_ptr<MapDataset> Map(const std::vector<std::reference_wrapper<TensorTransform>> &operations,
|
||||||
const std::vector<std::string> &input_columns = {},
|
const std::vector<std::string> &input_columns = {},
|
||||||
const std::vector<std::string> &output_columns = {},
|
const std::vector<std::string> &output_columns = {},
|
||||||
const std::vector<std::string> &project_columns = {},
|
const std::vector<std::string> &project_columns = {},
|
||||||
const std::shared_ptr<DatasetCache> &cache = nullptr,
|
const std::shared_ptr<DatasetCache> &cache = nullptr,
|
||||||
std::vector<std::shared_ptr<DSCallback>> callbacks = {}) {
|
const std::vector<std::shared_ptr<DSCallback>> &callbacks = {}) {
|
||||||
std::vector<std::shared_ptr<TensorOperation>> transform_ops;
|
std::vector<std::shared_ptr<TensorOperation>> transform_ops;
|
||||||
(void)std::transform(operations.begin(), operations.end(), std::back_inserter(transform_ops),
|
(void)std::transform(operations.begin(), operations.end(), std::back_inserter(transform_ops),
|
||||||
[](TensorTransform &op) -> std::shared_ptr<TensorOperation> { return op.Parse(); });
|
[](TensorTransform &op) -> std::shared_ptr<TensorOperation> { return op.Parse(); });
|
||||||
|
@ -378,7 +377,7 @@ class MS_API Dataset : public std::enable_shared_from_this<Dataset> {
|
||||||
std::vector<std::pair<std::vector<char>, std::vector<int32_t>>> GetClassIndexingCharIF();
|
std::vector<std::pair<std::vector<char>, std::vector<int32_t>>> GetClassIndexingCharIF();
|
||||||
|
|
||||||
// Char interface(CharIF) of CreateIterator
|
// Char interface(CharIF) of CreateIterator
|
||||||
std::shared_ptr<Iterator> CreateIteratorCharIF(std::vector<std::vector<char>> columns, int32_t num_epochs);
|
std::shared_ptr<Iterator> CreateIteratorCharIF(const std::vector<std::vector<char>> &columns, int32_t num_epochs);
|
||||||
|
|
||||||
// Char interface(CharIF) of DeviceQueue
|
// Char interface(CharIF) of DeviceQueue
|
||||||
bool DeviceQueueCharIF(const std::vector<char> &queue_name, const std::vector<char> &device_type, int32_t device_id,
|
bool DeviceQueueCharIF(const std::vector<char> &queue_name, const std::vector<char> &device_type, int32_t device_id,
|
||||||
|
@ -443,7 +442,7 @@ class MS_API SchemaObj {
|
||||||
std::string to_string() { return to_json(); }
|
std::string to_string() { return to_json(); }
|
||||||
|
|
||||||
/// \brief Set a new value to dataset_type
|
/// \brief Set a new value to dataset_type
|
||||||
void set_dataset_type(std::string dataset_type);
|
void set_dataset_type(const std::string &dataset_type);
|
||||||
|
|
||||||
/// \brief Set a new value to num_rows
|
/// \brief Set a new value to num_rows
|
||||||
void set_num_rows(int32_t num_rows);
|
void set_num_rows(int32_t num_rows);
|
||||||
|
@ -492,28 +491,32 @@ class MS_API SchemaObj {
|
||||||
|
|
||||||
class MS_API BatchDataset : public Dataset {
|
class MS_API BatchDataset : public Dataset {
|
||||||
public:
|
public:
|
||||||
BatchDataset(std::shared_ptr<Dataset> input, int32_t batch_size, bool drop_remainder = false);
|
BatchDataset(const std::shared_ptr<Dataset> &input, int32_t batch_size, bool drop_remainder = false);
|
||||||
|
|
||||||
~BatchDataset() = default;
|
~BatchDataset() = default;
|
||||||
};
|
};
|
||||||
|
|
||||||
class MS_API MapDataset : public Dataset {
|
class MS_API MapDataset : public Dataset {
|
||||||
public:
|
public:
|
||||||
MapDataset(std::shared_ptr<Dataset> input, std::vector<std::shared_ptr<TensorOperation>> operations,
|
MapDataset(const std::shared_ptr<Dataset> &input, const std::vector<std::shared_ptr<TensorOperation>> &operations,
|
||||||
const std::vector<std::vector<char>> &input_columns, const std::vector<std::vector<char>> &output_columns,
|
const std::vector<std::vector<char>> &input_columns, const std::vector<std::vector<char>> &output_columns,
|
||||||
const std::vector<std::vector<char>> &project_columns, const std::shared_ptr<DatasetCache> &cache,
|
const std::vector<std::vector<char>> &project_columns, const std::shared_ptr<DatasetCache> &cache,
|
||||||
std::vector<std::shared_ptr<DSCallback>> callbacks);
|
const std::vector<std::shared_ptr<DSCallback>> &callbacks);
|
||||||
|
|
||||||
~MapDataset() = default;
|
~MapDataset() = default;
|
||||||
};
|
};
|
||||||
|
|
||||||
class MS_API ProjectDataset : public Dataset {
|
class MS_API ProjectDataset : public Dataset {
|
||||||
public:
|
public:
|
||||||
ProjectDataset(std::shared_ptr<Dataset> input, const std::vector<std::vector<char>> &columns);
|
ProjectDataset(const std::shared_ptr<Dataset> &input, const std::vector<std::vector<char>> &columns);
|
||||||
|
|
||||||
~ProjectDataset() = default;
|
~ProjectDataset() = default;
|
||||||
};
|
};
|
||||||
|
|
||||||
class MS_API ShuffleDataset : public Dataset {
|
class MS_API ShuffleDataset : public Dataset {
|
||||||
public:
|
public:
|
||||||
ShuffleDataset(std::shared_ptr<Dataset> input, int32_t buffer_size);
|
ShuffleDataset(const std::shared_ptr<Dataset> &input, int32_t buffer_size);
|
||||||
|
|
||||||
~ShuffleDataset() = default;
|
~ShuffleDataset() = default;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -566,7 +569,7 @@ class MS_API AlbumDataset : public Dataset {
|
||||||
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
|
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
|
||||||
AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
|
AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
|
||||||
const std::vector<std::vector<char>> &column_names, bool decode,
|
const std::vector<std::vector<char>> &column_names, bool decode,
|
||||||
const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache);
|
const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
|
||||||
|
|
||||||
/// \brief Destructor of AlbumDataset.
|
/// \brief Destructor of AlbumDataset.
|
||||||
~AlbumDataset() = default;
|
~AlbumDataset() = default;
|
||||||
|
@ -664,7 +667,7 @@ class MS_API MnistDataset : public Dataset {
|
||||||
/// \param[in] sampler Sampler object used to choose samples from the dataset.
|
/// \param[in] sampler Sampler object used to choose samples from the dataset.
|
||||||
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
|
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
|
||||||
MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||||
const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache);
|
const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
|
||||||
|
|
||||||
/// Destructor of MnistDataset.
|
/// Destructor of MnistDataset.
|
||||||
~MnistDataset() = default;
|
~MnistDataset() = default;
|
||||||
|
@ -726,5 +729,4 @@ inline std::shared_ptr<MnistDataset> MS_API Mnist(const std::string &dataset_dir
|
||||||
}
|
}
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_DATASETS_H_
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_DATASETS_H_
|
||||||
|
|
|
@ -54,7 +54,7 @@ TEST_F(MindDataTestPipeline, TestRiaaBiquadBasicSampleRate44100) {
|
||||||
std::vector<int64_t> expected = {2, 200};
|
std::vector<int64_t> expected = {2, 200};
|
||||||
|
|
||||||
int i = 0;
|
int i = 0;
|
||||||
while (row.size() != 0) {
|
while (!row.empty()) {
|
||||||
auto col = row["waveform"];
|
auto col = row["waveform"];
|
||||||
ASSERT_EQ(col.Shape(), expected);
|
ASSERT_EQ(col.Shape(), expected);
|
||||||
ASSERT_EQ(col.Shape().size(), 2);
|
ASSERT_EQ(col.Shape().size(), 2);
|
||||||
|
@ -93,7 +93,7 @@ TEST_F(MindDataTestPipeline, TestRiaaBiquadBasicSampleRate48000) {
|
||||||
std::vector<int64_t> expected = {30, 40};
|
std::vector<int64_t> expected = {30, 40};
|
||||||
|
|
||||||
int i = 0;
|
int i = 0;
|
||||||
while (row.size() != 0) {
|
while (!row.empty()) {
|
||||||
auto col = row["waveform"];
|
auto col = row["waveform"];
|
||||||
ASSERT_EQ(col.Shape(), expected);
|
ASSERT_EQ(col.Shape(), expected);
|
||||||
ASSERT_EQ(col.Shape().size(), 2);
|
ASSERT_EQ(col.Shape().size(), 2);
|
||||||
|
@ -132,7 +132,7 @@ TEST_F(MindDataTestPipeline, TestRiaaBiquadBasicSampleRate88200) {
|
||||||
std::vector<int64_t> expected = {5, 4};
|
std::vector<int64_t> expected = {5, 4};
|
||||||
|
|
||||||
int i = 0;
|
int i = 0;
|
||||||
while (row.size() != 0) {
|
while (!row.empty()) {
|
||||||
auto col = row["waveform"];
|
auto col = row["waveform"];
|
||||||
ASSERT_EQ(col.Shape(), expected);
|
ASSERT_EQ(col.Shape(), expected);
|
||||||
ASSERT_EQ(col.Shape().size(), 2);
|
ASSERT_EQ(col.Shape().size(), 2);
|
||||||
|
@ -171,7 +171,7 @@ TEST_F(MindDataTestPipeline, TestRiaaBiquadBasicSampleRate96000) {
|
||||||
std::vector<int64_t> expected = {2, 3};
|
std::vector<int64_t> expected = {2, 3};
|
||||||
|
|
||||||
int i = 0;
|
int i = 0;
|
||||||
while (row.size() != 0) {
|
while (!row.empty()) {
|
||||||
auto col = row["waveform"];
|
auto col = row["waveform"];
|
||||||
ASSERT_EQ(col.Shape(), expected);
|
ASSERT_EQ(col.Shape(), expected);
|
||||||
ASSERT_EQ(col.Shape().size(), 2);
|
ASSERT_EQ(col.Shape().size(), 2);
|
||||||
|
@ -221,7 +221,7 @@ TEST_F(MindDataTestPipeline, TestSlidingWindowCmn) {
|
||||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||||
ASSERT_OK(iter->GetNextRow(&row));
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
uint64_t i = 0;
|
uint64_t i = 0;
|
||||||
while (row.size() != 0) {
|
while (!row.empty()) {
|
||||||
i++;
|
i++;
|
||||||
ASSERT_OK(iter->GetNextRow(&row));
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
}
|
}
|
||||||
|
@ -266,7 +266,7 @@ TEST_F(MindDataTestPipeline, TestSpectralCentroidBasic) {
|
||||||
std::shared_ptr<Dataset> ds = RandomData(8, schema);
|
std::shared_ptr<Dataset> ds = RandomData(8, schema);
|
||||||
EXPECT_NE(ds, nullptr);
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
auto spectral_centroid = audio::SpectralCentroid({44100, 8, 8, 4, 1, WindowType::kHann});
|
auto spectral_centroid = audio::SpectralCentroid(44100, 8, 8, 4, 1, WindowType::kHann);
|
||||||
|
|
||||||
auto ds1 = ds->Map({spectral_centroid}, {"waveform"});
|
auto ds1 = ds->Map({spectral_centroid}, {"waveform"});
|
||||||
EXPECT_NE(ds1, nullptr);
|
EXPECT_NE(ds1, nullptr);
|
||||||
|
@ -278,7 +278,7 @@ TEST_F(MindDataTestPipeline, TestSpectralCentroidBasic) {
|
||||||
ASSERT_OK(iter->GetNextRow(&row));
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
|
||||||
uint64_t i = 0;
|
uint64_t i = 0;
|
||||||
while (row.size() != 0) {
|
while (!row.empty()) {
|
||||||
ASSERT_OK(iter->GetNextRow(&row));
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
i++;
|
i++;
|
||||||
}
|
}
|
||||||
|
@ -297,7 +297,7 @@ TEST_F(MindDataTestPipeline, TestSpectralCentroidDefault) {
|
||||||
std::shared_ptr<Dataset> ds = RandomData(8, schema);
|
std::shared_ptr<Dataset> ds = RandomData(8, schema);
|
||||||
EXPECT_NE(ds, nullptr);
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
auto spectral_centroid = audio::SpectralCentroid({44100});
|
auto spectral_centroid = audio::SpectralCentroid(44100);
|
||||||
|
|
||||||
auto ds1 = ds->Map({spectral_centroid}, {"waveform"});
|
auto ds1 = ds->Map({spectral_centroid}, {"waveform"});
|
||||||
EXPECT_NE(ds1, nullptr);
|
EXPECT_NE(ds1, nullptr);
|
||||||
|
@ -309,7 +309,7 @@ TEST_F(MindDataTestPipeline, TestSpectralCentroidDefault) {
|
||||||
ASSERT_OK(iter->GetNextRow(&row));
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
|
||||||
uint64_t i = 0;
|
uint64_t i = 0;
|
||||||
while (row.size() != 0) {
|
while (!row.empty()) {
|
||||||
ASSERT_OK(iter->GetNextRow(&row));
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
i++;
|
i++;
|
||||||
}
|
}
|
||||||
|
@ -336,7 +336,7 @@ TEST_F(MindDataTestPipeline, TestSpectralCentroidWrongArgs) {
|
||||||
|
|
||||||
// Check n_fft
|
// Check n_fft
|
||||||
MS_LOG(INFO) << "n_fft is zero.";
|
MS_LOG(INFO) << "n_fft is zero.";
|
||||||
auto spectral_centroid_op_1 = audio::SpectralCentroid({44100, 0, 8, 4, 1, WindowType::kHann});
|
auto spectral_centroid_op_1 = audio::SpectralCentroid(44100, 0, 8, 4, 1, WindowType::kHann);
|
||||||
ds01 = ds->Map({spectral_centroid_op_1});
|
ds01 = ds->Map({spectral_centroid_op_1});
|
||||||
EXPECT_NE(ds01, nullptr);
|
EXPECT_NE(ds01, nullptr);
|
||||||
|
|
||||||
|
@ -345,7 +345,7 @@ TEST_F(MindDataTestPipeline, TestSpectralCentroidWrongArgs) {
|
||||||
|
|
||||||
// Check win_length
|
// Check win_length
|
||||||
MS_LOG(INFO) << "win_length is -1.";
|
MS_LOG(INFO) << "win_length is -1.";
|
||||||
auto spectral_centroid_op_2 = audio::SpectralCentroid({44100, 8, -1, 4, 1, WindowType::kHann});
|
auto spectral_centroid_op_2 = audio::SpectralCentroid(44100, 8, -1, 4, 1, WindowType::kHann);
|
||||||
ds02 = ds->Map({spectral_centroid_op_2});
|
ds02 = ds->Map({spectral_centroid_op_2});
|
||||||
EXPECT_NE(ds02, nullptr);
|
EXPECT_NE(ds02, nullptr);
|
||||||
|
|
||||||
|
@ -354,7 +354,7 @@ TEST_F(MindDataTestPipeline, TestSpectralCentroidWrongArgs) {
|
||||||
|
|
||||||
// Check hop_length
|
// Check hop_length
|
||||||
MS_LOG(INFO) << "hop_length is -1.";
|
MS_LOG(INFO) << "hop_length is -1.";
|
||||||
auto spectral_centroid_op_3 = audio::SpectralCentroid({44100, 8, 8, -1, 1, WindowType::kHann});
|
auto spectral_centroid_op_3 = audio::SpectralCentroid(44100, 8, 8, -1, 1, WindowType::kHann);
|
||||||
ds03 = ds->Map({spectral_centroid_op_3});
|
ds03 = ds->Map({spectral_centroid_op_3});
|
||||||
EXPECT_NE(ds03, nullptr);
|
EXPECT_NE(ds03, nullptr);
|
||||||
|
|
||||||
|
@ -363,7 +363,7 @@ TEST_F(MindDataTestPipeline, TestSpectralCentroidWrongArgs) {
|
||||||
|
|
||||||
// Check pad
|
// Check pad
|
||||||
MS_LOG(INFO) << "pad is -1.";
|
MS_LOG(INFO) << "pad is -1.";
|
||||||
auto spectral_centroid_op_4 = audio::SpectralCentroid({44100, 8, 8, 4, -1, WindowType::kHann});
|
auto spectral_centroid_op_4 = audio::SpectralCentroid(44100, 8, 8, 4, -1, WindowType::kHann);
|
||||||
ds04 = ds->Map({spectral_centroid_op_4});
|
ds04 = ds->Map({spectral_centroid_op_4});
|
||||||
EXPECT_NE(ds04, nullptr);
|
EXPECT_NE(ds04, nullptr);
|
||||||
|
|
||||||
|
@ -372,7 +372,7 @@ TEST_F(MindDataTestPipeline, TestSpectralCentroidWrongArgs) {
|
||||||
|
|
||||||
// Check sample_rate
|
// Check sample_rate
|
||||||
MS_LOG(INFO) << "sample_rate is -1.";
|
MS_LOG(INFO) << "sample_rate is -1.";
|
||||||
auto spectral_centroid_op_5 = audio::SpectralCentroid({-1, 8, 8, 4, 8, WindowType::kHann});
|
auto spectral_centroid_op_5 = audio::SpectralCentroid(-1, 8, 8, 4, 8, WindowType::kHann);
|
||||||
ds05 = ds->Map({spectral_centroid_op_5});
|
ds05 = ds->Map({spectral_centroid_op_5});
|
||||||
EXPECT_NE(ds05, nullptr);
|
EXPECT_NE(ds05, nullptr);
|
||||||
|
|
||||||
|
@ -392,7 +392,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramDefault) {
|
||||||
EXPECT_NE(ds, nullptr);
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
auto spectrogram =
|
auto spectrogram =
|
||||||
audio::Spectrogram({40, 40, 20, 0, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true});
|
audio::Spectrogram(40, 40, 20, 0, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true);
|
||||||
|
|
||||||
auto ds1 = ds->Map({spectrogram}, {"waveform"});
|
auto ds1 = ds->Map({spectrogram}, {"waveform"});
|
||||||
EXPECT_NE(ds1, nullptr);
|
EXPECT_NE(ds1, nullptr);
|
||||||
|
@ -404,7 +404,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramDefault) {
|
||||||
ASSERT_OK(iter->GetNextRow(&row));
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
|
||||||
uint64_t i = 0;
|
uint64_t i = 0;
|
||||||
while (row.size() != 0) {
|
while (!row.empty()) {
|
||||||
ASSERT_OK(iter->GetNextRow(&row));
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
i++;
|
i++;
|
||||||
}
|
}
|
||||||
|
@ -424,7 +424,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramOnesidedFalse) {
|
||||||
EXPECT_NE(ds, nullptr);
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
auto spectrogram =
|
auto spectrogram =
|
||||||
audio::Spectrogram({40, 40, 20, 0, WindowType::kHann, 2.0, false, true, BorderType::kReflect, false});
|
audio::Spectrogram(40, 40, 20, 0, WindowType::kHann, 2.0, false, true, BorderType::kReflect, false);
|
||||||
|
|
||||||
auto ds1 = ds->Map({spectrogram}, {"waveform"});
|
auto ds1 = ds->Map({spectrogram}, {"waveform"});
|
||||||
EXPECT_NE(ds1, nullptr);
|
EXPECT_NE(ds1, nullptr);
|
||||||
|
@ -436,7 +436,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramOnesidedFalse) {
|
||||||
ASSERT_OK(iter->GetNextRow(&row));
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
|
||||||
uint64_t i = 0;
|
uint64_t i = 0;
|
||||||
while (row.size() != 0) {
|
while (!row.empty()) {
|
||||||
ASSERT_OK(iter->GetNextRow(&row));
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
i++;
|
i++;
|
||||||
}
|
}
|
||||||
|
@ -456,7 +456,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramCenterFalse) {
|
||||||
EXPECT_NE(ds, nullptr);
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
auto spectrogram =
|
auto spectrogram =
|
||||||
audio::Spectrogram({40, 40, 20, 0, WindowType::kHann, 2.0, false, false, BorderType::kReflect, true});
|
audio::Spectrogram(40, 40, 20, 0, WindowType::kHann, 2.0, false, false, BorderType::kReflect, true);
|
||||||
|
|
||||||
auto ds1 = ds->Map({spectrogram}, {"waveform"});
|
auto ds1 = ds->Map({spectrogram}, {"waveform"});
|
||||||
EXPECT_NE(ds1, nullptr);
|
EXPECT_NE(ds1, nullptr);
|
||||||
|
@ -468,7 +468,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramCenterFalse) {
|
||||||
ASSERT_OK(iter->GetNextRow(&row));
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
|
||||||
uint64_t i = 0;
|
uint64_t i = 0;
|
||||||
while (row.size() != 0) {
|
while (!row.empty()) {
|
||||||
ASSERT_OK(iter->GetNextRow(&row));
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
i++;
|
i++;
|
||||||
}
|
}
|
||||||
|
@ -488,7 +488,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramNormalizedTrue) {
|
||||||
EXPECT_NE(ds, nullptr);
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
auto spectrogram =
|
auto spectrogram =
|
||||||
audio::Spectrogram({40, 40, 20, 0, WindowType::kHann, 2.0, true, true, BorderType::kReflect, true});
|
audio::Spectrogram(40, 40, 20, 0, WindowType::kHann, 2.0, true, true, BorderType::kReflect, true);
|
||||||
|
|
||||||
auto ds1 = ds->Map({spectrogram}, {"waveform"});
|
auto ds1 = ds->Map({spectrogram}, {"waveform"});
|
||||||
EXPECT_NE(ds1, nullptr);
|
EXPECT_NE(ds1, nullptr);
|
||||||
|
@ -500,7 +500,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramNormalizedTrue) {
|
||||||
ASSERT_OK(iter->GetNextRow(&row));
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
|
||||||
uint64_t i = 0;
|
uint64_t i = 0;
|
||||||
while (row.size() != 0) {
|
while (!row.empty()) {
|
||||||
ASSERT_OK(iter->GetNextRow(&row));
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
i++;
|
i++;
|
||||||
}
|
}
|
||||||
|
@ -520,7 +520,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramWindowHamming) {
|
||||||
EXPECT_NE(ds, nullptr);
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
auto spectrogram =
|
auto spectrogram =
|
||||||
audio::Spectrogram({40, 40, 20, 0, WindowType::kHamming, 2.0, false, true, BorderType::kReflect, true});
|
audio::Spectrogram(40, 40, 20, 0, WindowType::kHamming, 2.0, false, true, BorderType::kReflect, true);
|
||||||
|
|
||||||
auto ds1 = ds->Map({spectrogram}, {"waveform"});
|
auto ds1 = ds->Map({spectrogram}, {"waveform"});
|
||||||
EXPECT_NE(ds1, nullptr);
|
EXPECT_NE(ds1, nullptr);
|
||||||
|
@ -532,7 +532,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramWindowHamming) {
|
||||||
ASSERT_OK(iter->GetNextRow(&row));
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
|
||||||
uint64_t i = 0;
|
uint64_t i = 0;
|
||||||
while (row.size() != 0) {
|
while (!row.empty()) {
|
||||||
ASSERT_OK(iter->GetNextRow(&row));
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
i++;
|
i++;
|
||||||
}
|
}
|
||||||
|
@ -552,7 +552,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramPadmodeEdge) {
|
||||||
EXPECT_NE(ds, nullptr);
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
auto spectrogram =
|
auto spectrogram =
|
||||||
audio::Spectrogram({40, 40, 20, 0, WindowType::kHamming, 2.0, false, true, BorderType::kEdge, true});
|
audio::Spectrogram(40, 40, 20, 0, WindowType::kHamming, 2.0, false, true, BorderType::kEdge, true);
|
||||||
|
|
||||||
auto ds1 = ds->Map({spectrogram}, {"waveform"});
|
auto ds1 = ds->Map({spectrogram}, {"waveform"});
|
||||||
EXPECT_NE(ds1, nullptr);
|
EXPECT_NE(ds1, nullptr);
|
||||||
|
@ -564,7 +564,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramPadmodeEdge) {
|
||||||
ASSERT_OK(iter->GetNextRow(&row));
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
|
||||||
uint64_t i = 0;
|
uint64_t i = 0;
|
||||||
while (row.size() != 0) {
|
while (!row.empty()) {
|
||||||
ASSERT_OK(iter->GetNextRow(&row));
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
i++;
|
i++;
|
||||||
}
|
}
|
||||||
|
@ -584,7 +584,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramPower0) {
|
||||||
EXPECT_NE(ds, nullptr);
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
auto spectrogram =
|
auto spectrogram =
|
||||||
audio::Spectrogram({40, 40, 20, 0, WindowType::kHamming, 0, false, true, BorderType::kReflect, true});
|
audio::Spectrogram(40, 40, 20, 0, WindowType::kHamming, 0, false, true, BorderType::kReflect, true);
|
||||||
|
|
||||||
auto ds1 = ds->Map({spectrogram}, {"waveform"});
|
auto ds1 = ds->Map({spectrogram}, {"waveform"});
|
||||||
EXPECT_NE(ds1, nullptr);
|
EXPECT_NE(ds1, nullptr);
|
||||||
|
@ -596,7 +596,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramPower0) {
|
||||||
ASSERT_OK(iter->GetNextRow(&row));
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
|
||||||
uint64_t i = 0;
|
uint64_t i = 0;
|
||||||
while (row.size() != 0) {
|
while (!row.empty()) {
|
||||||
ASSERT_OK(iter->GetNextRow(&row));
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
i++;
|
i++;
|
||||||
}
|
}
|
||||||
|
@ -616,7 +616,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramNfft50) {
|
||||||
EXPECT_NE(ds, nullptr);
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
auto spectrogram =
|
auto spectrogram =
|
||||||
audio::Spectrogram({50, 40, 20, 0, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true});
|
audio::Spectrogram(50, 40, 20, 0, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true);
|
||||||
|
|
||||||
auto ds1 = ds->Map({spectrogram}, {"waveform"});
|
auto ds1 = ds->Map({spectrogram}, {"waveform"});
|
||||||
EXPECT_NE(ds1, nullptr);
|
EXPECT_NE(ds1, nullptr);
|
||||||
|
@ -628,7 +628,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramNfft50) {
|
||||||
ASSERT_OK(iter->GetNextRow(&row));
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
|
||||||
uint64_t i = 0;
|
uint64_t i = 0;
|
||||||
while (row.size() != 0) {
|
while (!row.empty()) {
|
||||||
ASSERT_OK(iter->GetNextRow(&row));
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
i++;
|
i++;
|
||||||
}
|
}
|
||||||
|
@ -648,7 +648,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramPad10) {
|
||||||
EXPECT_NE(ds, nullptr);
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
auto spectrogram =
|
auto spectrogram =
|
||||||
audio::Spectrogram({40, 40, 20, 10, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true});
|
audio::Spectrogram(40, 40, 20, 10, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true);
|
||||||
|
|
||||||
auto ds1 = ds->Map({spectrogram}, {"waveform"});
|
auto ds1 = ds->Map({spectrogram}, {"waveform"});
|
||||||
EXPECT_NE(ds1, nullptr);
|
EXPECT_NE(ds1, nullptr);
|
||||||
|
@ -660,7 +660,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramPad10) {
|
||||||
ASSERT_OK(iter->GetNextRow(&row));
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
|
||||||
uint64_t i = 0;
|
uint64_t i = 0;
|
||||||
while (row.size() != 0) {
|
while (!row.empty()) {
|
||||||
ASSERT_OK(iter->GetNextRow(&row));
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
i++;
|
i++;
|
||||||
}
|
}
|
||||||
|
@ -680,7 +680,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramWinlength30) {
|
||||||
EXPECT_NE(ds, nullptr);
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
auto spectrogram =
|
auto spectrogram =
|
||||||
audio::Spectrogram({40, 30, 20, 0, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true});
|
audio::Spectrogram(40, 30, 20, 0, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true);
|
||||||
|
|
||||||
auto ds1 = ds->Map({spectrogram}, {"waveform"});
|
auto ds1 = ds->Map({spectrogram}, {"waveform"});
|
||||||
EXPECT_NE(ds1, nullptr);
|
EXPECT_NE(ds1, nullptr);
|
||||||
|
@ -692,7 +692,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramWinlength30) {
|
||||||
ASSERT_OK(iter->GetNextRow(&row));
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
|
||||||
uint64_t i = 0;
|
uint64_t i = 0;
|
||||||
while (row.size() != 0) {
|
while (!row.empty()) {
|
||||||
ASSERT_OK(iter->GetNextRow(&row));
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
i++;
|
i++;
|
||||||
}
|
}
|
||||||
|
@ -712,7 +712,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramHoplength30) {
|
||||||
EXPECT_NE(ds, nullptr);
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
auto spectrogram =
|
auto spectrogram =
|
||||||
audio::Spectrogram({40, 40, 30, 0, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true});
|
audio::Spectrogram(40, 40, 30, 0, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true);
|
||||||
|
|
||||||
auto ds1 = ds->Map({spectrogram}, {"waveform"});
|
auto ds1 = ds->Map({spectrogram}, {"waveform"});
|
||||||
EXPECT_NE(ds1, nullptr);
|
EXPECT_NE(ds1, nullptr);
|
||||||
|
@ -724,7 +724,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramHoplength30) {
|
||||||
ASSERT_OK(iter->GetNextRow(&row));
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
|
||||||
uint64_t i = 0;
|
uint64_t i = 0;
|
||||||
while (row.size() != 0) {
|
while (!row.empty()) {
|
||||||
ASSERT_OK(iter->GetNextRow(&row));
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
i++;
|
i++;
|
||||||
}
|
}
|
||||||
|
@ -753,7 +753,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramWrongArgs) {
|
||||||
// Check n_fft
|
// Check n_fft
|
||||||
MS_LOG(INFO) << "n_fft is zero.";
|
MS_LOG(INFO) << "n_fft is zero.";
|
||||||
auto spectrogram_op_01 =
|
auto spectrogram_op_01 =
|
||||||
audio::Spectrogram({0, 40, 20, 0, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true});
|
audio::Spectrogram(0, 40, 20, 0, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true);
|
||||||
ds01 = ds->Map({spectrogram_op_01});
|
ds01 = ds->Map({spectrogram_op_01});
|
||||||
EXPECT_NE(ds01, nullptr);
|
EXPECT_NE(ds01, nullptr);
|
||||||
|
|
||||||
|
@ -763,7 +763,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramWrongArgs) {
|
||||||
// Check win_length
|
// Check win_length
|
||||||
MS_LOG(INFO) << "win_length is -1.";
|
MS_LOG(INFO) << "win_length is -1.";
|
||||||
auto spectrogram_op_02 =
|
auto spectrogram_op_02 =
|
||||||
audio::Spectrogram({40, -1, 20, 0, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true});
|
audio::Spectrogram(40, -1, 20, 0, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true);
|
||||||
ds02 = ds->Map({spectrogram_op_02});
|
ds02 = ds->Map({spectrogram_op_02});
|
||||||
EXPECT_NE(ds02, nullptr);
|
EXPECT_NE(ds02, nullptr);
|
||||||
|
|
||||||
|
@ -773,7 +773,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramWrongArgs) {
|
||||||
// Check hop_length
|
// Check hop_length
|
||||||
MS_LOG(INFO) << "hop_length is -1.";
|
MS_LOG(INFO) << "hop_length is -1.";
|
||||||
auto spectrogram_op_03 =
|
auto spectrogram_op_03 =
|
||||||
audio::Spectrogram({40, 40, -1, 0, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true});
|
audio::Spectrogram(40, 40, -1, 0, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true);
|
||||||
ds03 = ds->Map({spectrogram_op_03});
|
ds03 = ds->Map({spectrogram_op_03});
|
||||||
EXPECT_NE(ds03, nullptr);
|
EXPECT_NE(ds03, nullptr);
|
||||||
|
|
||||||
|
@ -783,7 +783,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramWrongArgs) {
|
||||||
// Check power
|
// Check power
|
||||||
MS_LOG(INFO) << "power is -1.";
|
MS_LOG(INFO) << "power is -1.";
|
||||||
auto spectrogram_op_04 =
|
auto spectrogram_op_04 =
|
||||||
audio::Spectrogram({40, 40, 20, 0, WindowType::kHann, -1, false, true, BorderType::kReflect, true});
|
audio::Spectrogram(40, 40, 20, 0, WindowType::kHann, -1, false, true, BorderType::kReflect, true);
|
||||||
ds04 = ds->Map({spectrogram_op_04});
|
ds04 = ds->Map({spectrogram_op_04});
|
||||||
EXPECT_NE(ds04, nullptr);
|
EXPECT_NE(ds04, nullptr);
|
||||||
|
|
||||||
|
@ -793,7 +793,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramWrongArgs) {
|
||||||
// Check pad
|
// Check pad
|
||||||
MS_LOG(INFO) << "pad is -1.";
|
MS_LOG(INFO) << "pad is -1.";
|
||||||
auto spectrogram_op_05 =
|
auto spectrogram_op_05 =
|
||||||
audio::Spectrogram({40, 40, 20, -1, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true});
|
audio::Spectrogram(40, 40, 20, -1, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true);
|
||||||
ds05 = ds->Map({spectrogram_op_05});
|
ds05 = ds->Map({spectrogram_op_05});
|
||||||
EXPECT_NE(ds05, nullptr);
|
EXPECT_NE(ds05, nullptr);
|
||||||
|
|
||||||
|
@ -803,7 +803,7 @@ TEST_F(MindDataTestPipeline, TestSpectrogramWrongArgs) {
|
||||||
// Check n_fft and win)length
|
// Check n_fft and win)length
|
||||||
MS_LOG(INFO) << "n_fft is 40, win_length is 50.";
|
MS_LOG(INFO) << "n_fft is 40, win_length is 50.";
|
||||||
auto spectrogram_op_06 =
|
auto spectrogram_op_06 =
|
||||||
audio::Spectrogram({40, 50, 20, -1, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true});
|
audio::Spectrogram(40, 50, 20, -1, WindowType::kHann, 2.0, false, true, BorderType::kReflect, true);
|
||||||
ds06 = ds->Map({spectrogram_op_06});
|
ds06 = ds->Map({spectrogram_op_06});
|
||||||
EXPECT_NE(ds06, nullptr);
|
EXPECT_NE(ds06, nullptr);
|
||||||
|
|
||||||
|
@ -837,7 +837,7 @@ TEST_F(MindDataTestPipeline, TestTimeMaskingPipeline) {
|
||||||
std::vector<int64_t> expected = {2, 200};
|
std::vector<int64_t> expected = {2, 200};
|
||||||
|
|
||||||
int i = 0;
|
int i = 0;
|
||||||
while (row.size() != 0) {
|
while (!row.empty()) {
|
||||||
auto col = row["inputData"];
|
auto col = row["inputData"];
|
||||||
ASSERT_EQ(col.Shape(), expected);
|
ASSERT_EQ(col.Shape(), expected);
|
||||||
ASSERT_EQ(col.Shape().size(), 2);
|
ASSERT_EQ(col.Shape().size(), 2);
|
||||||
|
@ -901,7 +901,7 @@ TEST_F(MindDataTestPipeline, TestTimeStretchPipeline) {
|
||||||
std::vector<int64_t> expected = {2, freq, static_cast<int64_t>(std::ceil(400 / rate)), 2};
|
std::vector<int64_t> expected = {2, freq, static_cast<int64_t>(std::ceil(400 / rate)), 2};
|
||||||
|
|
||||||
int i = 0;
|
int i = 0;
|
||||||
while (row.size() != 0) {
|
while (!row.empty()) {
|
||||||
auto col = row["inputData"];
|
auto col = row["inputData"];
|
||||||
ASSERT_EQ(col.Shape(), expected);
|
ASSERT_EQ(col.Shape(), expected);
|
||||||
ASSERT_EQ(col.DataType(), mindspore::DataType::kNumberTypeFloat32);
|
ASSERT_EQ(col.DataType(), mindspore::DataType::kNumberTypeFloat32);
|
||||||
|
@ -965,7 +965,7 @@ TEST_F(MindDataTestPipeline, TestTrebleBiquadBasic) {
|
||||||
std::vector<int64_t> expected = {2, 200};
|
std::vector<int64_t> expected = {2, 200};
|
||||||
|
|
||||||
int i = 0;
|
int i = 0;
|
||||||
while (row.size() != 0) {
|
while (!row.empty()) {
|
||||||
auto col = row["waveform"];
|
auto col = row["waveform"];
|
||||||
ASSERT_EQ(col.Shape(), expected);
|
ASSERT_EQ(col.Shape(), expected);
|
||||||
ASSERT_EQ(col.Shape().size(), 2);
|
ASSERT_EQ(col.Shape().size(), 2);
|
||||||
|
@ -1032,7 +1032,7 @@ TEST_F(MindDataTestPipeline, TestVolPipeline) {
|
||||||
std::vector<int64_t> expected = {2, 200};
|
std::vector<int64_t> expected = {2, 200};
|
||||||
|
|
||||||
int i = 0;
|
int i = 0;
|
||||||
while (row.size() != 0) {
|
while (!row.empty()) {
|
||||||
auto col = row["inputData"];
|
auto col = row["inputData"];
|
||||||
ASSERT_EQ(col.Shape(), expected);
|
ASSERT_EQ(col.Shape(), expected);
|
||||||
ASSERT_EQ(col.Shape().size(), 2);
|
ASSERT_EQ(col.Shape().size(), 2);
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
|
Loading…
Reference in New Issue