forked from mindspore-Ecosystem/mindspore
Add IR nodes for shuffle, take, repeat, bucket-batch-by-length, build-vocab, project, concat and rename
Follow-up fixes for the concat, bucket-batch, project and rename nodes; CI fixes (rounds 1 and 2) and clean-up
This commit is contained in:
parent
70bb0a842a
commit
d471552fc5
|
@ -41,19 +41,8 @@
|
|||
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
|
||||
#endif
|
||||
// Dataset operator headers (in alphabetical order)
|
||||
#include "minddata/dataset/engine/datasetops/batch_op.h"
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h"
|
||||
#endif
|
||||
#include "minddata/dataset/engine/datasetops/build_vocab_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/concat_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/project_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/rename_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/repeat_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/shuffle_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/skip_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/take_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/zip_op.h"
|
||||
|
||||
// Sampler headers (in alphabetical order)
|
||||
|
@ -61,8 +50,21 @@
|
|||
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
|
||||
// IR nodes
|
||||
// IR non-leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/take_node.h"
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
|
||||
#endif
|
||||
|
||||
// IR leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
|
||||
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
|
@ -1759,175 +1761,9 @@ std::vector<std::shared_ptr<DatasetOp>> VOCNode::Build() {
|
|||
#endif
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// Constructor for BucketBatchByLengthNode.
// Copies the bucketing configuration into the node and records `child` as the node's input dataset.
BucketBatchByLengthNode::BucketBatchByLengthNode(
  std::shared_ptr<Dataset> child, const std::vector<std::string> &column_names,
  const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
  std::function<TensorRow(TensorRow)> element_length_function,
  const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info, bool pad_to_bucket_boundary,
  bool drop_remainder)
    : column_names_(column_names),
      bucket_boundaries_(bucket_boundaries),
      bucket_batch_sizes_(bucket_batch_sizes),
      element_length_function_(element_length_function),
      pad_info_(pad_info),
      pad_to_bucket_boundary_(pad_to_bucket_boundary),
      drop_remainder_(drop_remainder) {
  // The upstream dataset is the only child of this node.
  this->children.emplace_back(child);
}
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> BucketBatchByLengthNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
std::shared_ptr<TensorOp> c_func;
|
||||
if (element_length_function_ != nullptr) {
|
||||
c_func = std::make_shared<CFuncOp>(element_length_function_);
|
||||
} else {
|
||||
c_func = nullptr;
|
||||
}
|
||||
node_ops.push_back(std::make_shared<BucketBatchByLengthOp>(column_names_, bucket_boundaries_, bucket_batch_sizes_,
|
||||
c_func, pad_info_, pad_to_bucket_boundary_,
|
||||
drop_remainder_, connector_que_size_));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
// Validates the parameters given to the constructor.
// Returns Status::OK() when every check passes; otherwise logs the problem and returns a syntax-error Status whose
// message describes the first violated rule (including the offending index/value where applicable).
Status BucketBatchByLengthNode::ValidateParams() {
  // When no element_length_function is given, exactly one column name must be supplied.
  if (element_length_function_ == nullptr && column_names_.size() != 1) {
    std::string err_msg = "BucketBatchByLengthNode: element_length_function not specified, but not one column name: " +
                          std::to_string(column_names_.size());
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  // Check bucket_boundaries: must be non-empty, positive and strictly increasing
  if (bucket_boundaries_.empty()) {
    std::string err_msg = "BucketBatchByLengthNode: bucket_boundaries cannot be empty.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  // size_t index avoids the signed/unsigned comparison against size().
  for (size_t i = 0; i < bucket_boundaries_.size(); ++i) {
    if (bucket_boundaries_[i] <= 0) {
      // The returned message carries the index and value so the caller sees the same detail as the log.
      std::string err_msg =
        "BucketBatchByLengthNode: bucket_boundaries must only contain positive numbers. However, the element at "
        "index: " +
        std::to_string(i) + " was: " + std::to_string(bucket_boundaries_[i]);
      MS_LOG(ERROR) << err_msg;
      RETURN_STATUS_SYNTAX_ERROR(err_msg);
    }
    if (i > 0 && bucket_boundaries_[i - 1] >= bucket_boundaries_[i]) {
      std::string err_msg =
        "BucketBatchByLengthNode: bucket_boundaries must be strictly increasing. However, the elements at index: " +
        std::to_string(i - 1) + " and " + std::to_string(i) + " were: " + std::to_string(bucket_boundaries_[i - 1]) +
        " and " + std::to_string(bucket_boundaries_[i]) + " respectively.";
      MS_LOG(ERROR) << err_msg;
      RETURN_STATUS_SYNTAX_ERROR(err_msg);
    }
  }

  // Check bucket_batch_sizes: must be non-empty, one longer than bucket_boundaries, and positive
  if (bucket_batch_sizes_.empty()) {
    std::string err_msg = "BucketBatchByLengthNode: bucket_batch_sizes must be non-empty";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  if (bucket_batch_sizes_.size() != bucket_boundaries_.size() + 1) {
    std::string err_msg =
      "BucketBatchByLengthNode: bucket_batch_sizes's size must equal the size of bucket_boundaries + 1";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  if (std::any_of(bucket_batch_sizes_.begin(), bucket_batch_sizes_.end(), [](int32_t size) { return size <= 0; })) {
    std::string err_msg = "BucketBatchByLengthNode: bucket_batch_sizes must only contain positive numbers.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  return Status::OK();
}
|
||||
|
||||
// Constructor for BuildVocabNode.
// Stores the vocab-building configuration and records `child` as the node's input dataset.
BuildVocabNode::BuildVocabNode(std::shared_ptr<Dataset> child, std::shared_ptr<Vocab> vocab,
                               const std::vector<std::string> &columns, const std::pair<int64_t, int64_t> &freq_range,
                               int64_t top_k, const std::vector<std::string> &special_tokens, bool special_first)
    : vocab_(vocab),
      columns_(columns),
      freq_range_(freq_range),
      top_k_(top_k),
      special_tokens_(special_tokens),
      special_first_(special_first) {
  // The upstream dataset is the only child of this node.
  this->children.emplace_back(child);
}
|
||||
|
||||
// Function to build BuildVocabNode
|
||||
std::vector<std::shared_ptr<DatasetOp>> BuildVocabNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
std::shared_ptr<BuildVocabOp> build_vocab_op;
|
||||
build_vocab_op = std::make_shared<BuildVocabOp>(vocab_, columns_, freq_range_, top_k_, special_tokens_,
|
||||
special_first_, num_workers_, connector_que_size_);
|
||||
node_ops.push_back(build_vocab_op);
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
// Validates the parameters given to the constructor.
// Returns Status::OK() when all checks pass, otherwise a syntax-error Status for the first failing check.
Status BuildVocabNode::ValidateParams() {
  // A vocab object to fill is mandatory.
  if (vocab_ == nullptr) {
    std::string err_msg = "BuildVocabNode: vocab is null.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  // top_k must be strictly positive.
  const bool top_k_ok = top_k_ > 0;
  if (!top_k_ok) {
    std::string err_msg = "BuildVocabNode: top_k should be positive, but got: " + std::to_string(top_k_);
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  // The frequency range must satisfy 0 <= first <= second <= kDeMaxFreq.
  const bool range_ok = freq_range_.first >= 0 && freq_range_.second <= kDeMaxFreq &&
                        freq_range_.first <= freq_range_.second;
  if (!range_ok) {
    std::string err_msg = "BuildVocabNode: frequency_range [a,b] violates 0 <= a <= b (a,b are inclusive)";
    MS_LOG(ERROR) << "BuildVocabNode: frequency_range [a,b] should be 0 <= a <= b (a,b are inclusive), "
                  << "but got [" << freq_range_.first << ", " << freq_range_.second << "]";
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  // Column names are only checked when some were given.
  if (!columns_.empty()) {
    RETURN_IF_NOT_OK(ValidateDatasetColumnParam("BuildVocabNode", "columns", columns_));
  }

  return Status::OK();
}
|
||||
#endif
|
||||
|
||||
// Function to build ConcatOp
|
||||
ConcatNode::ConcatNode(const std::vector<std::shared_ptr<Dataset>> &datasets) : datasets_(datasets) {
|
||||
this->children = datasets_;
|
||||
}
|
||||
|
||||
// Validates the datasets given to the constructor.
// Returns Status::OK() when the list is non-empty and contains no null entry; otherwise logs the problem and
// returns a syntax-error Status.
Status ConcatNode::ValidateParams() {
  if (datasets_.empty()) {
    std::string err_msg = "ConcatNode: concatenated datasets are not specified.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  // Explicit scan instead of an unqualified find(): no reliance on argument-dependent lookup or on <algorithm>
  // being pulled in transitively.
  for (const auto &dataset : datasets_) {
    if (dataset == nullptr) {
      std::string err_msg = "ConcatNode: concatenated datasets should not be null.";
      MS_LOG(ERROR) << err_msg;
      RETURN_STATUS_SYNTAX_ERROR(err_msg);
    }
  }

  return Status::OK();
}
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> ConcatNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
node_ops.push_back(std::make_shared<ConcatOp>(connector_que_size_));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
MapNode::MapNode(std::shared_ptr<Dataset> child, std::vector<std::shared_ptr<TensorOperation>> operations,
|
||||
std::vector<std::string> input_columns, std::vector<std::string> output_columns,
|
||||
const std::vector<std::string> &project_columns)
|
||||
|
@ -1984,110 +1820,6 @@ Status MapNode::ValidateParams() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Function to build ProjectOp
|
||||
ProjectNode::ProjectNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &columns) : columns_(columns) {
|
||||
this->children.push_back(child);
|
||||
}
|
||||
|
||||
// Validates the column list given to the constructor.
// Returns Status::OK() when at least one column is given and the list passes ValidateDatasetColumnParam.
Status ProjectNode::ValidateParams() {
  const bool no_columns = columns_.empty();
  if (no_columns) {
    std::string err_msg = "ProjectNode: No columns are specified.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  // Delegate the per-name checks to the shared column validator.
  RETURN_IF_NOT_OK(ValidateDatasetColumnParam("ProjectNode", "columns", columns_));

  return Status::OK();
}
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> ProjectNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
node_ops.push_back(std::make_shared<ProjectOp>(columns_));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
// Function to build RenameOp
|
||||
RenameNode::RenameNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &input_columns,
|
||||
const std::vector<std::string> &output_columns)
|
||||
: input_columns_(input_columns), output_columns_(output_columns) {
|
||||
this->children.push_back(child);
|
||||
}
|
||||
|
||||
// Validates the column lists given to the constructor.
// Returns Status::OK() when both lists have the same length and each passes ValidateDatasetColumnParam.
Status RenameNode::ValidateParams() {
  const bool same_length = input_columns_.size() == output_columns_.size();
  if (!same_length) {
    std::string err_msg = "RenameNode: input and output columns must be the same size";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  // Input names are validated first, then output names.
  RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RenameNode", "input_columns", input_columns_));
  RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RenameNode", "output_columns", output_columns_));

  return Status::OK();
}
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> RenameNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
node_ops.push_back(std::make_shared<RenameOp>(input_columns_, output_columns_, connector_que_size_));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
// Constructor for RepeatNode.
// Stores the repeat count and records `child` as the node's input dataset.
RepeatNode::RepeatNode(std::shared_ptr<Dataset> child, int32_t count)
    : repeat_count_(count) {
  this->children.emplace_back(child);
}
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> RepeatNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
node_ops.push_back(std::make_shared<RepeatOp>(repeat_count_));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
// Validates the repeat count given to the constructor.
// Returns Status::OK() when the count is -1 or strictly positive, otherwise a syntax-error Status.
Status RepeatNode::ValidateParams() {
  const bool count_is_valid = repeat_count_ > 0 || repeat_count_ == -1;
  if (!count_is_valid) {
    std::string err_msg = "RepeatNode: repeat_count should be either -1 or positive integer, repeat_count_: " +
                          std::to_string(repeat_count_);
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  return Status::OK();
}
|
||||
|
||||
// Constructor for ShuffleNode
|
||||
ShuffleNode::ShuffleNode(std::shared_ptr<Dataset> child, int32_t shuffle_size, bool reset_every_epoch)
|
||||
: shuffle_size_(shuffle_size), shuffle_seed_(GetSeed()), reset_every_epoch_(reset_every_epoch) {
|
||||
this->children.push_back(child);
|
||||
}
|
||||
|
||||
// Function to build the ShuffleOp
|
||||
std::vector<std::shared_ptr<DatasetOp>> ShuffleNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
node_ops.push_back(std::make_shared<ShuffleOp>(shuffle_size_, shuffle_seed_, connector_que_size_, reset_every_epoch_,
|
||||
rows_per_buffer_));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
// Function to validate the parameters for ShuffleNode
|
||||
Status ShuffleNode::ValidateParams() {
|
||||
if (shuffle_size_ <= 1) {
|
||||
std::string err_msg = "ShuffleNode: Invalid input, shuffle_size: " + std::to_string(shuffle_size_);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Constructor for SkipNode
|
||||
SkipNode::SkipNode(std::shared_ptr<Dataset> child, int32_t count) : skip_count_(count) {
|
||||
this->children.push_back(child);
|
||||
|
@ -2113,31 +1845,6 @@ Status SkipNode::ValidateParams() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Constructor for TakeNode
|
||||
TakeNode::TakeNode(std::shared_ptr<Dataset> child, int32_t count) : take_count_(count) {
|
||||
this->children.push_back(child);
|
||||
}
|
||||
|
||||
// Function to build the TakeOp
|
||||
std::vector<std::shared_ptr<DatasetOp>> TakeNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
node_ops.push_back(std::make_shared<TakeOp>(take_count_, connector_que_size_));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
// Function to validate the parameters for TakeNode
|
||||
Status TakeNode::ValidateParams() {
|
||||
if (take_count_ <= 0 && take_count_ != -1) {
|
||||
std::string err_msg =
|
||||
"TakeNode: take_count should be either -1 or positive integer, take_count: " + std::to_string(take_count_);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Function to build ZipOp
|
||||
ZipNode::ZipNode(const std::vector<std::shared_ptr<Dataset>> &datasets) : datasets_(datasets) {
|
||||
for (auto dataset : datasets_) {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,22 @@
|
|||
# Tag every source in this directory with the minddata sub-module id.
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_subdirectory(source)

# IR dataset-op nodes compiled on every platform.
set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES
    batch_node.cc
    concat_node.cc
    project_node.cc
    rename_node.cc
    repeat_node.cc
    shuffle_node.cc
    take_node.cc
    )

# These nodes are left out of Android builds.
if (NOT ENABLE_ANDROID)
    list(APPEND DATASET_ENGINE_IR_DATASETOPS_SRC_FILES
        bucket_batch_by_length_node.cc
        build_vocab_node.cc)
endif ()

# The target is declared exactly once, from the assembled source list.
add_library(engine-ir-datasetops OBJECT ${DATASET_ENGINE_IR_DATASETOPS_SRC_FILES})
|
||||
|
|
|
@ -0,0 +1,121 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h"
|
||||
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
// Constructor for BucketBatchByLengthNode.
// Copies the bucketing configuration into the node and records `child` as the node's input dataset.
// `child` and `element_length_function` are received by value, so they are moved into place rather than copied
// a second time.
BucketBatchByLengthNode::BucketBatchByLengthNode(
  std::shared_ptr<Dataset> child, const std::vector<std::string> &column_names,
  const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
  std::function<TensorRow(TensorRow)> element_length_function,
  const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info, bool pad_to_bucket_boundary,
  bool drop_remainder)
    : column_names_(column_names),
      bucket_boundaries_(bucket_boundaries),
      bucket_batch_sizes_(bucket_batch_sizes),
      element_length_function_(std::move(element_length_function)),
      pad_info_(pad_info),
      pad_to_bucket_boundary_(pad_to_bucket_boundary),
      drop_remainder_(drop_remainder) {
  // The upstream dataset is the only child of this node.
  this->children.push_back(std::move(child));
}
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> BucketBatchByLengthNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
std::shared_ptr<TensorOp> c_func;
|
||||
if (element_length_function_ != nullptr) {
|
||||
c_func = std::make_shared<CFuncOp>(element_length_function_);
|
||||
} else {
|
||||
c_func = nullptr;
|
||||
}
|
||||
node_ops.push_back(std::make_shared<BucketBatchByLengthOp>(column_names_, bucket_boundaries_, bucket_batch_sizes_,
|
||||
c_func, pad_info_, pad_to_bucket_boundary_,
|
||||
drop_remainder_, connector_que_size_));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
// Validates the parameters given to the constructor.
// Returns Status::OK() when every check passes; otherwise logs the problem and returns a syntax-error Status whose
// message describes the first violated rule (including the offending index/value where applicable).
Status BucketBatchByLengthNode::ValidateParams() {
  // When no element_length_function is given, exactly one column name must be supplied.
  if (element_length_function_ == nullptr && column_names_.size() != 1) {
    std::string err_msg = "BucketBatchByLengthNode: element_length_function not specified, but not one column name: " +
                          std::to_string(column_names_.size());
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  // Check bucket_boundaries: must be non-empty, positive and strictly increasing
  if (bucket_boundaries_.empty()) {
    std::string err_msg = "BucketBatchByLengthNode: bucket_boundaries cannot be empty.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  // size_t index avoids the signed/unsigned comparison against size().
  for (size_t i = 0; i < bucket_boundaries_.size(); ++i) {
    if (bucket_boundaries_[i] <= 0) {
      // The returned message carries the index and value so the caller sees the same detail as the log.
      std::string err_msg =
        "BucketBatchByLengthNode: bucket_boundaries must only contain positive numbers. However, the element at "
        "index: " +
        std::to_string(i) + " was: " + std::to_string(bucket_boundaries_[i]);
      MS_LOG(ERROR) << err_msg;
      RETURN_STATUS_SYNTAX_ERROR(err_msg);
    }
    if (i > 0 && bucket_boundaries_[i - 1] >= bucket_boundaries_[i]) {
      std::string err_msg =
        "BucketBatchByLengthNode: bucket_boundaries must be strictly increasing. However, the elements at index: " +
        std::to_string(i - 1) + " and " + std::to_string(i) + " were: " + std::to_string(bucket_boundaries_[i - 1]) +
        " and " + std::to_string(bucket_boundaries_[i]) + " respectively.";
      MS_LOG(ERROR) << err_msg;
      RETURN_STATUS_SYNTAX_ERROR(err_msg);
    }
  }

  // Check bucket_batch_sizes: must be non-empty, one longer than bucket_boundaries, and positive
  if (bucket_batch_sizes_.empty()) {
    std::string err_msg = "BucketBatchByLengthNode: bucket_batch_sizes must be non-empty";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  if (bucket_batch_sizes_.size() != bucket_boundaries_.size() + 1) {
    std::string err_msg =
      "BucketBatchByLengthNode: bucket_batch_sizes's size must equal the size of bucket_boundaries + 1";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  if (std::any_of(bucket_batch_sizes_.begin(), bucket_batch_sizes_.end(), [](int32_t size) { return size <= 0; })) {
    std::string err_msg = "BucketBatchByLengthNode: bucket_batch_sizes must only contain positive numbers.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  return Status::OK();
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,64 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_BUCKET_BATCH_BY_LENGTH_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_BUCKET_BATCH_BY_LENGTH_NODE_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
class BucketBatchByLengthNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
BucketBatchByLengthNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &column_names,
|
||||
const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
|
||||
std::function<TensorRow(TensorRow)> element_length_function = nullptr,
|
||||
const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info = {},
|
||||
bool pad_to_bucket_boundary = false, bool drop_remainder = false);
|
||||
|
||||
/// \brief Destructor
|
||||
~BucketBatchByLengthNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<std::string> column_names_;
|
||||
std::vector<int32_t> bucket_boundaries_;
|
||||
std::vector<int32_t> bucket_batch_sizes_;
|
||||
std::function<TensorRow(TensorRow)> element_length_function_;
|
||||
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_info_;
|
||||
bool pad_to_bucket_boundary_;
|
||||
bool drop_remainder_;
|
||||
};
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_BUCKET_BATCH_BY_LENGTH_NODE_H_
|
|
@ -0,0 +1,83 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/build_vocab_op.h"
|
||||
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
|
||||
// Constructor for BuildVocabNode.
// Stores the vocab-building configuration and records `child` as the node's input dataset.
// `child` and `vocab` are received by value, so they are moved into place to avoid an extra reference-count
// round trip.
BuildVocabNode::BuildVocabNode(std::shared_ptr<Dataset> child, std::shared_ptr<Vocab> vocab,
                               const std::vector<std::string> &columns, const std::pair<int64_t, int64_t> &freq_range,
                               int64_t top_k, const std::vector<std::string> &special_tokens, bool special_first)
    : vocab_(std::move(vocab)),
      columns_(columns),
      freq_range_(freq_range),
      top_k_(top_k),
      special_tokens_(special_tokens),
      special_first_(special_first) {
  // The upstream dataset is the only child of this node.
  this->children.push_back(std::move(child));
}
|
||||
|
||||
// Function to build BuildVocabNode
|
||||
std::vector<std::shared_ptr<DatasetOp>> BuildVocabNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
std::shared_ptr<BuildVocabOp> build_vocab_op;
|
||||
build_vocab_op = std::make_shared<BuildVocabOp>(vocab_, columns_, freq_range_, top_k_, special_tokens_,
|
||||
special_first_, num_workers_, connector_que_size_);
|
||||
node_ops.push_back(build_vocab_op);
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
// Validates the parameters given to the constructor.
// Returns Status::OK() when all checks pass, otherwise a syntax-error Status for the first failing check.
Status BuildVocabNode::ValidateParams() {
  // A vocab object to fill is mandatory.
  if (vocab_ == nullptr) {
    std::string err_msg = "BuildVocabNode: vocab is null.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  // top_k must be strictly positive.
  const bool top_k_ok = top_k_ > 0;
  if (!top_k_ok) {
    std::string err_msg = "BuildVocabNode: top_k should be positive, but got: " + std::to_string(top_k_);
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  // The frequency range must satisfy 0 <= first <= second <= kDeMaxFreq.
  const bool range_ok = freq_range_.first >= 0 && freq_range_.second <= kDeMaxFreq &&
                        freq_range_.first <= freq_range_.second;
  if (!range_ok) {
    std::string err_msg = "BuildVocabNode: frequency_range [a,b] violates 0 <= a <= b (a,b are inclusive)";
    MS_LOG(ERROR) << "BuildVocabNode: frequency_range [a,b] should be 0 <= a <= b (a,b are inclusive), "
                  << "but got [" << freq_range_.first << ", " << freq_range_.second << "]";
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  // Column names are only checked when some were given.
  if (!columns_.empty()) {
    RETURN_IF_NOT_OK(ValidateDatasetColumnParam("BuildVocabNode", "columns", columns_));
  }

  return Status::OK();
}
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,61 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_BUILD_VOCAB_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_BUILD_VOCAB_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
|
||||
class BuildVocabNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
BuildVocabNode(std::shared_ptr<Dataset> child, std::shared_ptr<Vocab> vocab, const std::vector<std::string> &columns,
|
||||
const std::pair<int64_t, int64_t> &freq_range, int64_t top_k,
|
||||
const std::vector<std::string> &special_tokens, bool special_first);
|
||||
|
||||
/// \brief Destructor
|
||||
~BuildVocabNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<Vocab> vocab_;
|
||||
std::vector<std::string> columns_;
|
||||
std::pair<int64_t, int64_t> freq_range_;
|
||||
int64_t top_k_;
|
||||
std::vector<std::string> special_tokens_;
|
||||
bool special_first_;
|
||||
};
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_BUILD_VOCAB_NODE_H_
|
|
@ -0,0 +1,60 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/concat_op.h"
|
||||
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
// Function to build ConcatOp
|
||||
ConcatNode::ConcatNode(const std::vector<std::shared_ptr<Dataset>> &datasets) : datasets_(datasets) {
|
||||
this->children = datasets_;
|
||||
}
|
||||
|
||||
// Function to validate the parameters for ConcatNode
// Returns Status::OK() when at least one dataset was supplied and none of the entries is null; otherwise logs the
// problem and returns through RETURN_STATUS_SYNTAX_ERROR.
Status ConcatNode::ValidateParams() {
  if (datasets_.empty()) {
    std::string err_msg = "ConcatNode: concatenated datasets are not specified.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  // An explicit scan keeps this file independent of <algorithm>, which it does not include, and avoids relying on
  // argument-dependent lookup to resolve an unqualified find().
  for (const auto &dataset : datasets_) {
    if (dataset == nullptr) {
      std::string err_msg = "ConcatNode: concatenated datasets should not be null.";
      MS_LOG(ERROR) << err_msg;
      RETURN_STATUS_SYNTAX_ERROR(err_msg);
    }
  }

  return Status::OK();
}
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> ConcatNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
node_ops.push_back(std::make_shared<ConcatOp>(connector_que_size_));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,53 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CONCAT_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CONCAT_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
|
||||
class ConcatNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
explicit ConcatNode(const std::vector<std::shared_ptr<Dataset>> &datasets);
|
||||
|
||||
/// \brief Destructor
|
||||
~ConcatNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<std::shared_ptr<Dataset>> datasets_;
|
||||
};
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CONCAT_NODE_H_
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/project_op.h"
|
||||
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
|
||||
// Function to build ProjectOp
|
||||
ProjectNode::ProjectNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &columns) : columns_(columns) {
|
||||
this->children.push_back(child);
|
||||
}
|
||||
|
||||
// Function to validate the parameters for ProjectNode
Status ProjectNode::ValidateParams() {
  const bool has_columns = !columns_.empty();
  if (!has_columns) {
    // Nothing to project: log and return a syntax error
    std::string message = "ProjectNode: No columns are specified.";
    MS_LOG(ERROR) << message;
    RETURN_STATUS_SYNTAX_ERROR(message);
  }

  // Delegate the check of the column list itself to the shared column-parameter validator
  RETURN_IF_NOT_OK(ValidateDatasetColumnParam("ProjectNode", "columns", columns_));
  return Status::OK();
}
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> ProjectNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
node_ops.push_back(std::make_shared<ProjectOp>(columns_));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,54 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_PROJECT_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_PROJECT_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
namespace api {
|
||||
|
||||
class ProjectNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
explicit ProjectNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &columns);
|
||||
|
||||
/// \brief Destructor
|
||||
~ProjectNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<std::string> columns_;
|
||||
};
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_PROJECT_NODE_H_
|
|
@ -0,0 +1,59 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/rename_op.h"
|
||||
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
// Function to build RenameOp
|
||||
RenameNode::RenameNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &input_columns,
|
||||
const std::vector<std::string> &output_columns)
|
||||
: input_columns_(input_columns), output_columns_(output_columns) {
|
||||
this->children.push_back(child);
|
||||
}
|
||||
|
||||
// Function to validate the parameters for RenameNode
Status RenameNode::ValidateParams() {
  // The two lists are paired up, so they must have the same number of entries
  const bool same_length = input_columns_.size() == output_columns_.size();
  if (!same_length) {
    std::string message = "RenameNode: input and output columns must be the same size";
    MS_LOG(ERROR) << message;
    RETURN_STATUS_SYNTAX_ERROR(message);
  }

  // Each list is then checked by the shared column-parameter validator, input list first
  RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RenameNode", "input_columns", input_columns_));
  RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RenameNode", "output_columns", output_columns_));
  return Status::OK();
}
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> RenameNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
node_ops.push_back(std::make_shared<RenameOp>(input_columns_, output_columns_, connector_que_size_));
|
||||
return node_ops;
|
||||
}
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,56 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RENAME_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RENAME_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
namespace api {
|
||||
|
||||
class RenameNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
explicit RenameNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &input_columns,
|
||||
const std::vector<std::string> &output_columns);
|
||||
|
||||
/// \brief Destructor
|
||||
~RenameNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<std::string> input_columns_;
|
||||
std::vector<std::string> output_columns_;
|
||||
};
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RENAME_NODE_H_
|
|
@ -0,0 +1,54 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/repeat_op.h"
|
||||
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
|
||||
// Constructor for RepeatNode: stores the repeat count and registers the child dataset
RepeatNode::RepeatNode(std::shared_ptr<Dataset> child, int32_t count) : repeat_count_(count) {
  children.push_back(child);
}
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> RepeatNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
node_ops.push_back(std::make_shared<RepeatOp>(repeat_count_));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
// Function to validate the parameters for RepeatNode
Status RepeatNode::ValidateParams() {
  // Accepted values: any positive count, or the special value -1
  const bool count_is_valid = (repeat_count_ > 0) || (repeat_count_ == -1);
  if (count_is_valid) {
    return Status::OK();
  }

  std::string message = "RepeatNode: repeat_count should be either -1 or positive integer, repeat_count_: " +
                        std::to_string(repeat_count_);
  MS_LOG(ERROR) << message;
  RETURN_STATUS_SYNTAX_ERROR(message);
}
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,56 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_REPEAT_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_REPEAT_NODE_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
namespace api {
|
||||
|
||||
class RepeatNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
explicit RepeatNode(std::shared_ptr<Dataset> child, int32_t count);
|
||||
|
||||
/// \brief Destructor
|
||||
~RepeatNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
int32_t repeat_count_;
|
||||
};
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_REPEAT_NODE_H_
|
|
@ -0,0 +1,59 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/shuffle_op.h"
|
||||
#include "minddata/dataset/util/random.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
|
||||
// Constructor for ShuffleNode
|
||||
ShuffleNode::ShuffleNode(std::shared_ptr<Dataset> child, int32_t shuffle_size, bool reset_every_epoch)
|
||||
: shuffle_size_(shuffle_size), shuffle_seed_(GetSeed()), reset_every_epoch_(reset_every_epoch) {
|
||||
this->children.push_back(child);
|
||||
}
|
||||
|
||||
// Function to build the ShuffleOp
|
||||
std::vector<std::shared_ptr<DatasetOp>> ShuffleNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
node_ops.push_back(std::make_shared<ShuffleOp>(shuffle_size_, shuffle_seed_, connector_que_size_, reset_every_epoch_,
|
||||
rows_per_buffer_));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
// Function to validate the parameters for ShuffleNode
|
||||
Status ShuffleNode::ValidateParams() {
|
||||
if (shuffle_size_ <= 1) {
|
||||
std::string err_msg = "ShuffleNode: Invalid input, shuffle_size: " + std::to_string(shuffle_size_);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,52 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SHUFFLE_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SHUFFLE_NODE_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
namespace api {
|
||||
|
||||
class ShuffleNode : public Dataset {
|
||||
public:
|
||||
ShuffleNode(std::shared_ptr<Dataset> child, int32_t shuffle_size, bool reset_every_epoch);
|
||||
|
||||
~ShuffleNode() = default;
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
int32_t shuffle_size_;
|
||||
uint32_t shuffle_seed_;
|
||||
bool reset_every_epoch_;
|
||||
};
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SHUFFLE_NODE_H_
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/take_node.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/take_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
// Constructor for TakeNode
|
||||
TakeNode::TakeNode(std::shared_ptr<Dataset> child, int32_t count) : take_count_(count) {
|
||||
this->children.push_back(child);
|
||||
}
|
||||
|
||||
// Function to build the TakeOp
|
||||
std::vector<std::shared_ptr<DatasetOp>> TakeNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
node_ops.push_back(std::make_shared<TakeOp>(take_count_, connector_que_size_));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
// Function to validate the parameters for TakeNode
|
||||
Status TakeNode::ValidateParams() {
|
||||
if (take_count_ <= 0 && take_count_ != -1) {
|
||||
std::string err_msg =
|
||||
"TakeNode: take_count should be either -1 or positive integer, take_count: " + std::to_string(take_count_);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,54 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TAKE_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TAKE_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
namespace api {
|
||||
|
||||
class TakeNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
explicit TakeNode(std::shared_ptr<Dataset> child, int32_t count);
|
||||
|
||||
/// \brief Destructor
|
||||
~TakeNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return shared pointer to the list of newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
int32_t take_count_;
|
||||
};
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TAKE_NODE_H_
|
|
@ -1201,85 +1201,6 @@ class VOCNode : public Dataset {
|
|||
// DERIVED DATASET CLASSES FOR DATASET OPS
|
||||
// (In alphabetical order)
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
class BucketBatchByLengthNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
BucketBatchByLengthNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &column_names,
|
||||
const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
|
||||
std::function<TensorRow(TensorRow)> element_length_function = nullptr,
|
||||
const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info = {},
|
||||
bool pad_to_bucket_boundary = false, bool drop_remainder = false);
|
||||
|
||||
/// \brief Destructor
|
||||
~BucketBatchByLengthNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<std::string> column_names_;
|
||||
std::vector<int32_t> bucket_boundaries_;
|
||||
std::vector<int32_t> bucket_batch_sizes_;
|
||||
std::function<TensorRow(TensorRow)> element_length_function_;
|
||||
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_info_;
|
||||
bool pad_to_bucket_boundary_;
|
||||
bool drop_remainder_;
|
||||
};
|
||||
|
||||
class BuildVocabNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
BuildVocabNode(std::shared_ptr<Dataset> child, std::shared_ptr<Vocab> vocab, const std::vector<std::string> &columns,
|
||||
const std::pair<int64_t, int64_t> &freq_range, int64_t top_k,
|
||||
const std::vector<std::string> &special_tokens, bool special_first);
|
||||
|
||||
/// \brief Destructor
|
||||
~BuildVocabNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<Vocab> vocab_;
|
||||
std::vector<std::string> columns_;
|
||||
std::pair<int64_t, int64_t> freq_range_;
|
||||
int64_t top_k_;
|
||||
std::vector<std::string> special_tokens_;
|
||||
bool special_first_;
|
||||
};
|
||||
#endif
|
||||
|
||||
class ConcatNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
explicit ConcatNode(const std::vector<std::shared_ptr<Dataset>> &datasets);
|
||||
|
||||
/// \brief Destructor
|
||||
~ConcatNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<std::shared_ptr<Dataset>> datasets_;
|
||||
};
|
||||
|
||||
class MapNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
|
@ -1305,84 +1226,6 @@ class MapNode : public Dataset {
|
|||
std::vector<std::string> project_columns_;
|
||||
};
|
||||
|
||||
class ProjectNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
explicit ProjectNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &columns);
|
||||
|
||||
/// \brief Destructor
|
||||
~ProjectNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<std::string> columns_;
|
||||
};
|
||||
|
||||
class RenameNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
explicit RenameNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &input_columns,
|
||||
const std::vector<std::string> &output_columns);
|
||||
|
||||
/// \brief Destructor
|
||||
~RenameNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<std::string> input_columns_;
|
||||
std::vector<std::string> output_columns_;
|
||||
};
|
||||
|
||||
class RepeatNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
explicit RepeatNode(std::shared_ptr<Dataset> child, int32_t count);
|
||||
|
||||
/// \brief Destructor
|
||||
~RepeatNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
int32_t repeat_count_;
|
||||
};
|
||||
|
||||
class ShuffleNode : public Dataset {
|
||||
public:
|
||||
ShuffleNode(std::shared_ptr<Dataset> child, int32_t shuffle_size, bool reset_every_epoch);
|
||||
|
||||
~ShuffleNode() = default;
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
int32_t shuffle_size_;
|
||||
uint32_t shuffle_seed_;
|
||||
bool reset_every_epoch_;
|
||||
};
|
||||
|
||||
class SkipNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
|
@ -1403,26 +1246,6 @@ class SkipNode : public Dataset {
|
|||
int32_t skip_count_;
|
||||
};
|
||||
|
||||
/// \brief Dataset node that holds a take count to be applied to its child dataset.
/// \note Only the declaration is shown here; the member definitions are not part of this header.
class TakeNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
/// \param[in] child The dataset this node is built on top of
/// \param[in] count Take count (presumably stored in take_count_ - TODO confirm against the definition)
|
||||
explicit TakeNode(std::shared_ptr<Dataset> child, int32_t count);
|
||||
|
||||
/// \brief Destructor
|
||||
~TakeNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
// Take count held by this node (presumably set from the constructor's count argument - TODO confirm)
int32_t take_count_;
|
||||
};
|
||||
|
||||
class ZipNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
|
|
|
@ -18,6 +18,13 @@
|
|||
#include "minddata/dataset/include/config.h"
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||
|
||||
using namespace mindspore::dataset::api;
|
||||
using mindspore::dataset::ShuffleMode;
|
||||
using mindspore::dataset::Tensor;
|
||||
|
|
|
@ -17,6 +17,11 @@
|
|||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||
|
||||
using namespace mindspore::dataset::api;
|
||||
using mindspore::dataset::Tensor;
|
||||
|
|
|
@ -18,8 +18,17 @@
|
|||
#include "minddata/dataset/include/datasets.h"
|
||||
#include "minddata/dataset/include/vision.h"
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/take_node.h"
|
||||
|
||||
using namespace mindspore::dataset::api;
|
||||
using mindspore::dataset::Tensor;
|
||||
|
||||
|
|
|
@ -16,10 +16,14 @@
|
|||
#include "common/common.h"
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
#include "minddata/dataset/include/vision.h"
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/core/global_context.h"
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using namespace mindspore::dataset::api;
|
||||
|
|
|
@ -16,8 +16,14 @@
|
|||
#include "common/common.h"
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
|
||||
|
||||
using namespace mindspore::dataset::api;
|
||||
using mindspore::dataset::Tensor;
|
||||
using mindspore::dataset::TensorShape;
|
||||
|
|
|
@ -16,8 +16,13 @@
|
|||
#include "common/common.h"
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
|
||||
|
||||
using namespace mindspore::dataset::api;
|
||||
using mindspore::dataset::Tensor;
|
||||
|
|
|
@ -19,6 +19,11 @@
|
|||
#include "minddata/dataset/include/vision.h"
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||
|
||||
using namespace mindspore::dataset::api;
|
||||
using mindspore::dataset::BorderType;
|
||||
|
|
|
@ -18,8 +18,14 @@
|
|||
#include "minddata/dataset/include/transforms.h"
|
||||
#include "minddata/dataset/include/vision.h"
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
|
||||
|
||||
using namespace mindspore::dataset::api;
|
||||
using mindspore::dataset::BorderType;
|
||||
using mindspore::dataset::Tensor;
|
||||
|
|
Loading…
Reference in New Issue