Add IR nodes for shuffle, take, repeat, bucket_batch_by_length, build_vocab, project, concat, and rename operations.

Also covers the concat, bucket_batch_by_length, project, and rename nodes.

fix ci round 1

fix ci round 2

fix up

fix ci
This commit is contained in:
Zirui Wu 2020-10-21 16:15:49 -04:00
parent 70bb0a842a
commit d471552fc5
29 changed files with 1085 additions and 493 deletions

View File

@ -41,19 +41,8 @@
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
#endif
// Dataset operator headers (in alphabetical order)
#include "minddata/dataset/engine/datasetops/batch_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h"
#endif
#include "minddata/dataset/engine/datasetops/build_vocab_op.h"
#include "minddata/dataset/engine/datasetops/concat_op.h"
#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
#include "minddata/dataset/engine/datasetops/project_op.h"
#include "minddata/dataset/engine/datasetops/rename_op.h"
#include "minddata/dataset/engine/datasetops/repeat_op.h"
#include "minddata/dataset/engine/datasetops/shuffle_op.h"
#include "minddata/dataset/engine/datasetops/skip_op.h"
#include "minddata/dataset/engine/datasetops/take_op.h"
#include "minddata/dataset/engine/datasetops/zip_op.h"
// Sampler headers (in alphabetical order)
@ -61,8 +50,21 @@
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
// IR nodes
// IR non-leaf nodes
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
#include "minddata/dataset/engine/ir/datasetops/take_node.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
#include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
#endif
// IR leaf nodes
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
#include "minddata/dataset/core/config_manager.h"
@ -1759,175 +1761,9 @@ std::vector<std::shared_ptr<DatasetOp>> VOCNode::Build() {
#endif
#ifndef ENABLE_ANDROID
// Constructor for BucketBatchByLengthNode.
// Stores all bucketing parameters and registers `child` as this IR node's input.
BucketBatchByLengthNode::BucketBatchByLengthNode(
  std::shared_ptr<Dataset> child, const std::vector<std::string> &column_names,
  const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
  std::function<TensorRow(TensorRow)> element_length_function,
  const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info, bool pad_to_bucket_boundary,
  bool drop_remainder)
    : column_names_(column_names),
      bucket_boundaries_(bucket_boundaries),
      bucket_batch_sizes_(bucket_batch_sizes),
      element_length_function_(element_length_function),
      pad_info_(pad_info),
      pad_to_bucket_boundary_(pad_to_bucket_boundary),
      drop_remainder_(drop_remainder) {
  // A non-leaf IR node keeps its input dataset in `children`.
  this->children.push_back(child);
}
std::vector<std::shared_ptr<DatasetOp>> BucketBatchByLengthNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
std::shared_ptr<TensorOp> c_func;
if (element_length_function_ != nullptr) {
c_func = std::make_shared<CFuncOp>(element_length_function_);
} else {
c_func = nullptr;
}
node_ops.push_back(std::make_shared<BucketBatchByLengthOp>(column_names_, bucket_boundaries_, bucket_batch_sizes_,
c_func, pad_info_, pad_to_bucket_boundary_,
drop_remainder_, connector_que_size_));
return node_ops;
}
// Validate all user-supplied parameters for BucketBatchByLengthNode.
// Checks: exactly one column name is required when no element_length_function
// is given; bucket_boundaries must be non-empty, positive, and strictly
// increasing; bucket_batch_sizes must be non-empty, positive, and exactly one
// element longer than bucket_boundaries.
Status BucketBatchByLengthNode::ValidateParams() {
  if (element_length_function_ == nullptr && column_names_.size() != 1) {
    std::string err_msg = "BucketBatchByLengthNode: element_length_function not specified, but not one column name: " +
                          std::to_string(column_names_.size());
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  // Check bucket_boundaries: must be positive and strictly increasing
  if (bucket_boundaries_.empty()) {
    std::string err_msg = "BucketBatchByLengthNode: bucket_boundaries cannot be empty.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  // size_t index avoids the signed/unsigned comparison against vector::size().
  for (size_t i = 0; i < bucket_boundaries_.size(); i++) {
    if (bucket_boundaries_[i] <= 0) {
      // Bug fix: the returned message previously ended at "index: " without the index value.
      std::string err_msg =
        "BucketBatchByLengthNode: Invalid non-positive bucket_boundaries, index: " + std::to_string(i);
      MS_LOG(ERROR)
        << "BucketBatchByLength: bucket_boundaries must only contain positive numbers. However, the element at index: "
        << i << " was: " << bucket_boundaries_[i];
      RETURN_STATUS_SYNTAX_ERROR(err_msg);
    }
    if (i > 0 && bucket_boundaries_[i - 1] >= bucket_boundaries_[i]) {
      // Bug fix: corrected ungrammatical error text ("not be strictly increasing").
      std::string err_msg = "BucketBatchByLengthNode: Invalid bucket_boundaries, must be strictly increasing.";
      MS_LOG(ERROR)
        << "BucketBatchByLength: bucket_boundaries must be strictly increasing. However, the elements at index: "
        << i - 1 << " and " << i << " were: " << bucket_boundaries_[i - 1] << " and " << bucket_boundaries_[i]
        << " respectively.";
      RETURN_STATUS_SYNTAX_ERROR(err_msg);
    }
  }
  // Check bucket_batch_sizes: must be positive
  if (bucket_batch_sizes_.empty()) {
    std::string err_msg = "BucketBatchByLengthNode: bucket_batch_sizes must be non-empty";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (bucket_batch_sizes_.size() != bucket_boundaries_.size() + 1) {
    std::string err_msg =
      "BucketBatchByLengthNode: bucket_batch_sizes's size must equal the size of bucket_boundaries + 1";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (std::any_of(bucket_batch_sizes_.begin(), bucket_batch_sizes_.end(), [](int32_t sz) { return sz <= 0; })) {
    std::string err_msg = "BucketBatchByLengthNode: bucket_batch_sizes must only contain positive numbers.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  return Status::OK();
}
// Constructor for BuildVocabNode.
// Stores the target vocab plus the frequency-range / top-k filters, and
// registers `child` as this IR node's input.
BuildVocabNode::BuildVocabNode(std::shared_ptr<Dataset> child, std::shared_ptr<Vocab> vocab,
                               const std::vector<std::string> &columns, const std::pair<int64_t, int64_t> &freq_range,
                               int64_t top_k, const std::vector<std::string> &special_tokens, bool special_first)
    : vocab_(vocab),
      columns_(columns),
      freq_range_(freq_range),
      top_k_(top_k),
      special_tokens_(special_tokens),
      special_first_(special_first) {
  // A non-leaf IR node keeps its input dataset in `children`.
  this->children.push_back(child);
}
// Function to build BuildVocabNode
std::vector<std::shared_ptr<DatasetOp>> BuildVocabNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
std::shared_ptr<BuildVocabOp> build_vocab_op;
build_vocab_op = std::make_shared<BuildVocabOp>(vocab_, columns_, freq_range_, top_k_, special_tokens_,
special_first_, num_workers_, connector_que_size_);
node_ops.push_back(build_vocab_op);
return node_ops;
}
// Validate vocab, top_k, freq_range, and the optional column list.
Status BuildVocabNode::ValidateParams() {
  if (vocab_ == nullptr) {
    std::string err_msg = "BuildVocabNode: vocab is null.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (top_k_ <= 0) {
    std::string err_msg = "BuildVocabNode: top_k should be positive, but got: " + std::to_string(top_k_);
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (freq_range_.first < 0 || freq_range_.second > kDeMaxFreq || freq_range_.first > freq_range_.second) {
    // Bug fix: include the actual offending range in the returned status,
    // not only in the log output.
    std::string err_msg = "BuildVocabNode: frequency_range [a,b] should be 0 <= a <= b (a,b are inclusive), but got [" +
                          std::to_string(freq_range_.first) + ", " + std::to_string(freq_range_.second) + "]";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  // columns is optional; validate the names only when some were supplied.
  if (!columns_.empty()) {
    RETURN_IF_NOT_OK(ValidateDatasetColumnParam("BuildVocabNode", "columns", columns_));
  }
  return Status::OK();
}
#endif
// Function to build ConcatOp
ConcatNode::ConcatNode(const std::vector<std::shared_ptr<Dataset>> &datasets) : datasets_(datasets) {
this->children = datasets_;
}
Status ConcatNode::ValidateParams() {
if (datasets_.empty()) {
std::string err_msg = "ConcatNode: concatenated datasets are not specified.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (find(datasets_.begin(), datasets_.end(), nullptr) != datasets_.end()) {
std::string err_msg = "ConcatNode: concatenated datasets should not be null.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}
std::vector<std::shared_ptr<DatasetOp>> ConcatNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<ConcatOp>(connector_que_size_));
return node_ops;
}
MapNode::MapNode(std::shared_ptr<Dataset> child, std::vector<std::shared_ptr<TensorOperation>> operations,
std::vector<std::string> input_columns, std::vector<std::string> output_columns,
const std::vector<std::string> &project_columns)
@ -1984,110 +1820,6 @@ Status MapNode::ValidateParams() {
return Status::OK();
}
// Function to build ProjectOp
ProjectNode::ProjectNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &columns) : columns_(columns) {
this->children.push_back(child);
}
Status ProjectNode::ValidateParams() {
if (columns_.empty()) {
std::string err_msg = "ProjectNode: No columns are specified.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("ProjectNode", "columns", columns_));
return Status::OK();
}
std::vector<std::shared_ptr<DatasetOp>> ProjectNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<ProjectOp>(columns_));
return node_ops;
}
// Function to build RenameOp
RenameNode::RenameNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &input_columns,
const std::vector<std::string> &output_columns)
: input_columns_(input_columns), output_columns_(output_columns) {
this->children.push_back(child);
}
Status RenameNode::ValidateParams() {
if (input_columns_.size() != output_columns_.size()) {
std::string err_msg = "RenameNode: input and output columns must be the same size";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RenameNode", "input_columns", input_columns_));
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RenameNode", "output_columns", output_columns_));
return Status::OK();
}
std::vector<std::shared_ptr<DatasetOp>> RenameNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<RenameOp>(input_columns_, output_columns_, connector_que_size_));
return node_ops;
}
RepeatNode::RepeatNode(std::shared_ptr<Dataset> child, int32_t count) : repeat_count_(count) {
this->children.push_back(child);
}
std::vector<std::shared_ptr<DatasetOp>> RepeatNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<RepeatOp>(repeat_count_));
return node_ops;
}
Status RepeatNode::ValidateParams() {
if (repeat_count_ <= 0 && repeat_count_ != -1) {
std::string err_msg = "RepeatNode: repeat_count should be either -1 or positive integer, repeat_count_: " +
std::to_string(repeat_count_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}
// Constructor for ShuffleNode
ShuffleNode::ShuffleNode(std::shared_ptr<Dataset> child, int32_t shuffle_size, bool reset_every_epoch)
: shuffle_size_(shuffle_size), shuffle_seed_(GetSeed()), reset_every_epoch_(reset_every_epoch) {
this->children.push_back(child);
}
// Function to build the ShuffleOp
std::vector<std::shared_ptr<DatasetOp>> ShuffleNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<ShuffleOp>(shuffle_size_, shuffle_seed_, connector_que_size_, reset_every_epoch_,
rows_per_buffer_));
return node_ops;
}
// Function to validate the parameters for ShuffleNode
Status ShuffleNode::ValidateParams() {
if (shuffle_size_ <= 1) {
std::string err_msg = "ShuffleNode: Invalid input, shuffle_size: " + std::to_string(shuffle_size_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}
// Constructor for SkipNode
SkipNode::SkipNode(std::shared_ptr<Dataset> child, int32_t count) : skip_count_(count) {
this->children.push_back(child);
@ -2113,31 +1845,6 @@ Status SkipNode::ValidateParams() {
return Status::OK();
}
// Constructor for TakeNode
TakeNode::TakeNode(std::shared_ptr<Dataset> child, int32_t count) : take_count_(count) {
this->children.push_back(child);
}
// Function to build the TakeOp
std::vector<std::shared_ptr<DatasetOp>> TakeNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<TakeOp>(take_count_, connector_que_size_));
return node_ops;
}
// Function to validate the parameters for TakeNode
Status TakeNode::ValidateParams() {
if (take_count_ <= 0 && take_count_ != -1) {
std::string err_msg =
"TakeNode: take_count should be either -1 or positive integer, take_count: " + std::to_string(take_count_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}
// Function to build ZipOp
ZipNode::ZipNode(const std::vector<std::shared_ptr<Dataset>> &datasets) : datasets_(datasets) {
for (auto dataset : datasets_) {

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,22 @@
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_subdirectory(source)

# Non-leaf IR dataset op sources common to all platforms.
set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES
    batch_node.cc
    concat_node.cc
    project_node.cc
    rename_node.cc
    repeat_node.cc
    shuffle_node.cc
    take_node.cc
    )

# bucket_batch_by_length and build_vocab are not built for Android.
if (NOT ENABLE_ANDROID)
    set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES
        ${DATASET_ENGINE_IR_DATASETOPS_SRC_FILES}
        bucket_batch_by_length_node.cc
        build_vocab_node.cc)
endif ()

# Bug fix: the target was declared twice (an early `add_library(... batch_node.cc)`
# plus this one), which is a CMake error — a target name may only be created once.
add_library(engine-ir-datasetops OBJECT ${DATASET_ENGINE_IR_DATASETOPS_SRC_FILES})

View File

@ -0,0 +1,121 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace api {
// Constructor for BucketBatchByLengthNode.
// Stores all bucketing parameters and registers `child` as this IR node's input.
BucketBatchByLengthNode::BucketBatchByLengthNode(
  std::shared_ptr<Dataset> child, const std::vector<std::string> &column_names,
  const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
  std::function<TensorRow(TensorRow)> element_length_function,
  const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info, bool pad_to_bucket_boundary,
  bool drop_remainder)
    : column_names_(column_names),
      bucket_boundaries_(bucket_boundaries),
      bucket_batch_sizes_(bucket_batch_sizes),
      element_length_function_(element_length_function),
      pad_info_(pad_info),
      pad_to_bucket_boundary_(pad_to_bucket_boundary),
      drop_remainder_(drop_remainder) {
  // A non-leaf IR node keeps its input dataset in `children`.
  this->children.push_back(child);
}
std::vector<std::shared_ptr<DatasetOp>> BucketBatchByLengthNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
std::shared_ptr<TensorOp> c_func;
if (element_length_function_ != nullptr) {
c_func = std::make_shared<CFuncOp>(element_length_function_);
} else {
c_func = nullptr;
}
node_ops.push_back(std::make_shared<BucketBatchByLengthOp>(column_names_, bucket_boundaries_, bucket_batch_sizes_,
c_func, pad_info_, pad_to_bucket_boundary_,
drop_remainder_, connector_que_size_));
return node_ops;
}
// Validate all user-supplied parameters for BucketBatchByLengthNode.
// Checks: exactly one column name is required when no element_length_function
// is given; bucket_boundaries must be non-empty, positive, and strictly
// increasing; bucket_batch_sizes must be non-empty, positive, and exactly one
// element longer than bucket_boundaries.
Status BucketBatchByLengthNode::ValidateParams() {
  if (element_length_function_ == nullptr && column_names_.size() != 1) {
    std::string err_msg = "BucketBatchByLengthNode: element_length_function not specified, but not one column name: " +
                          std::to_string(column_names_.size());
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  // Check bucket_boundaries: must be positive and strictly increasing
  if (bucket_boundaries_.empty()) {
    std::string err_msg = "BucketBatchByLengthNode: bucket_boundaries cannot be empty.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  // size_t index avoids the signed/unsigned comparison against vector::size().
  for (size_t i = 0; i < bucket_boundaries_.size(); i++) {
    if (bucket_boundaries_[i] <= 0) {
      // Bug fix: the returned message previously ended at "index: " without the index value.
      std::string err_msg =
        "BucketBatchByLengthNode: Invalid non-positive bucket_boundaries, index: " + std::to_string(i);
      MS_LOG(ERROR)
        << "BucketBatchByLength: bucket_boundaries must only contain positive numbers. However, the element at index: "
        << i << " was: " << bucket_boundaries_[i];
      RETURN_STATUS_SYNTAX_ERROR(err_msg);
    }
    if (i > 0 && bucket_boundaries_[i - 1] >= bucket_boundaries_[i]) {
      // Bug fix: corrected ungrammatical error text ("not be strictly increasing").
      std::string err_msg = "BucketBatchByLengthNode: Invalid bucket_boundaries, must be strictly increasing.";
      MS_LOG(ERROR)
        << "BucketBatchByLength: bucket_boundaries must be strictly increasing. However, the elements at index: "
        << i - 1 << " and " << i << " were: " << bucket_boundaries_[i - 1] << " and " << bucket_boundaries_[i]
        << " respectively.";
      RETURN_STATUS_SYNTAX_ERROR(err_msg);
    }
  }
  // Check bucket_batch_sizes: must be positive
  if (bucket_batch_sizes_.empty()) {
    std::string err_msg = "BucketBatchByLengthNode: bucket_batch_sizes must be non-empty";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (bucket_batch_sizes_.size() != bucket_boundaries_.size() + 1) {
    std::string err_msg =
      "BucketBatchByLengthNode: bucket_batch_sizes's size must equal the size of bucket_boundaries + 1";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (std::any_of(bucket_batch_sizes_.begin(), bucket_batch_sizes_.end(), [](int32_t sz) { return sz <= 0; })) {
    std::string err_msg = "BucketBatchByLengthNode: bucket_batch_sizes must only contain positive numbers.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,64 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_BUCKET_BATCH_BY_LENGTH_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_BUCKET_BATCH_BY_LENGTH_NODE_H_
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/include/datasets.h"
namespace mindspore {
namespace dataset {
namespace api {
/// \brief IR node for BucketBatchByLength: assigns rows to buckets by element
/// length and batches each bucket independently.
class BucketBatchByLengthNode : public Dataset {
 public:
  /// \brief Constructor
  BucketBatchByLengthNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &column_names,
                          const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
                          std::function<TensorRow(TensorRow)> element_length_function = nullptr,
                          const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info = {},
                          bool pad_to_bucket_boundary = false, bool drop_remainder = false);

  /// \brief Destructor
  ~BucketBatchByLengthNode() = default;

  /// \brief a base class override function to create the required runtime dataset op objects for this class
  /// \return The list of shared pointers to the newly created DatasetOps
  std::vector<std::shared_ptr<DatasetOp>> Build() override;

  /// \brief Parameters validation
  /// \return Status Status::OK() if all the parameters are valid
  Status ValidateParams() override;

 private:
  // Exactly one column is required when element_length_function_ is null.
  std::vector<std::string> column_names_;
  // Must be non-empty, positive, strictly increasing (see ValidateParams).
  std::vector<int32_t> bucket_boundaries_;
  // Must have bucket_boundaries_.size() + 1 positive entries.
  std::vector<int32_t> bucket_batch_sizes_;
  // Optional row -> length mapper; may be nullptr.
  std::function<TensorRow(TensorRow)> element_length_function_;
  // Per-column pad shape and pad value.
  std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_info_;
  bool pad_to_bucket_boundary_;
  bool drop_remainder_;
};
} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_BUCKET_BATCH_BY_LENGTH_NODE_H_

View File

@ -0,0 +1,83 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/datasetops/build_vocab_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace api {
// Constructor for BuildVocabNode.
// Stores the target vocab plus the frequency-range / top-k filters, and
// registers `child` as this IR node's input.
BuildVocabNode::BuildVocabNode(std::shared_ptr<Dataset> child, std::shared_ptr<Vocab> vocab,
                               const std::vector<std::string> &columns, const std::pair<int64_t, int64_t> &freq_range,
                               int64_t top_k, const std::vector<std::string> &special_tokens, bool special_first)
    : vocab_(vocab),
      columns_(columns),
      freq_range_(freq_range),
      top_k_(top_k),
      special_tokens_(special_tokens),
      special_first_(special_first) {
  // A non-leaf IR node keeps its input dataset in `children`.
  this->children.push_back(child);
}
// Function to build BuildVocabNode
std::vector<std::shared_ptr<DatasetOp>> BuildVocabNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
std::shared_ptr<BuildVocabOp> build_vocab_op;
build_vocab_op = std::make_shared<BuildVocabOp>(vocab_, columns_, freq_range_, top_k_, special_tokens_,
special_first_, num_workers_, connector_que_size_);
node_ops.push_back(build_vocab_op);
return node_ops;
}
// Validate vocab, top_k, freq_range, and the optional column list.
Status BuildVocabNode::ValidateParams() {
  if (vocab_ == nullptr) {
    std::string err_msg = "BuildVocabNode: vocab is null.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (top_k_ <= 0) {
    std::string err_msg = "BuildVocabNode: top_k should be positive, but got: " + std::to_string(top_k_);
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (freq_range_.first < 0 || freq_range_.second > kDeMaxFreq || freq_range_.first > freq_range_.second) {
    // Bug fix: include the actual offending range in the returned status,
    // not only in the log output.
    std::string err_msg = "BuildVocabNode: frequency_range [a,b] should be 0 <= a <= b (a,b are inclusive), but got [" +
                          std::to_string(freq_range_.first) + ", " + std::to_string(freq_range_.second) + "]";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  // columns is optional; validate the names only when some were supplied.
  if (!columns_.empty()) {
    RETURN_IF_NOT_OK(ValidateDatasetColumnParam("BuildVocabNode", "columns", columns_));
  }
  return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,61 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_BUILD_VOCAB_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_BUILD_VOCAB_NODE_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/include/datasets.h"
namespace mindspore {
namespace dataset {
namespace api {
/// \brief IR node for BuildVocab: builds a Vocab from the child dataset's text columns.
class BuildVocabNode : public Dataset {
 public:
  /// \brief Constructor
  BuildVocabNode(std::shared_ptr<Dataset> child, std::shared_ptr<Vocab> vocab, const std::vector<std::string> &columns,
                 const std::pair<int64_t, int64_t> &freq_range, int64_t top_k,
                 const std::vector<std::string> &special_tokens, bool special_first);

  /// \brief Destructor
  ~BuildVocabNode() = default;

  /// \brief a base class override function to create the required runtime dataset op objects for this class
  /// \return The list of shared pointers to the newly created DatasetOps
  std::vector<std::shared_ptr<DatasetOp>> Build() override;

  /// \brief Parameters validation
  /// \return Status Status::OK() if all the parameters are valid
  Status ValidateParams() override;

 private:
  // Vocab object to fill; must be non-null (see ValidateParams).
  std::shared_ptr<Vocab> vocab_;
  // Columns to draw tokens from; empty means no explicit column filter.
  std::vector<std::string> columns_;
  // Inclusive [min, max] token-frequency filter; 0 <= min <= max required.
  std::pair<int64_t, int64_t> freq_range_;
  // Number of most-frequent tokens to keep; must be positive.
  int64_t top_k_;
  std::vector<std::string> special_tokens_;
  // Whether special tokens are prepended (true) or appended to the vocab.
  bool special_first_;
};
} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_BUILD_VOCAB_NODE_H_

View File

@ -0,0 +1,60 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/datasetops/concat_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace api {
// Function to build ConcatOp
ConcatNode::ConcatNode(const std::vector<std::shared_ptr<Dataset>> &datasets) : datasets_(datasets) {
this->children = datasets_;
}
Status ConcatNode::ValidateParams() {
if (datasets_.empty()) {
std::string err_msg = "ConcatNode: concatenated datasets are not specified.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (find(datasets_.begin(), datasets_.end(), nullptr) != datasets_.end()) {
std::string err_msg = "ConcatNode: concatenated datasets should not be null.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}
std::vector<std::shared_ptr<DatasetOp>> ConcatNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<ConcatOp>(connector_que_size_));
return node_ops;
}
} // namespace api
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,53 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CONCAT_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CONCAT_NODE_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/include/datasets.h"
namespace mindspore {
namespace dataset {
namespace api {
/// \brief IR node for the Concat operation: concatenates a list of datasets.
class ConcatNode : public Dataset {
 public:
  /// \brief Constructor
  explicit ConcatNode(const std::vector<std::shared_ptr<Dataset>> &datasets);

  /// \brief Destructor
  ~ConcatNode() = default;

  /// \brief a base class override function to create the required runtime dataset op objects for this class
  /// \return The list of shared pointers to the newly created DatasetOps
  std::vector<std::shared_ptr<DatasetOp>> Build() override;

  /// \brief Parameters validation
  /// \return Status Status::OK() if all the parameters are valid
  Status ValidateParams() override;

 private:
  // Datasets to concatenate; the constructor also assigns them to Dataset::children.
  // Must be non-empty with no null entries (see ValidateParams).
  std::vector<std::shared_ptr<Dataset>> datasets_;
};
} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CONCAT_NODE_H_

View File

@ -0,0 +1,57 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/datasetops/project_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace api {
// Function to build ProjectOp
ProjectNode::ProjectNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &columns) : columns_(columns) {
this->children.push_back(child);
}
Status ProjectNode::ValidateParams() {
if (columns_.empty()) {
std::string err_msg = "ProjectNode: No columns are specified.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("ProjectNode", "columns", columns_));
return Status::OK();
}
std::vector<std::shared_ptr<DatasetOp>> ProjectNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<ProjectOp>(columns_));
return node_ops;
}
} // namespace api
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,54 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_PROJECT_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_PROJECT_NODE_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/include/datasets.h"
namespace mindspore {
namespace dataset {
namespace api {
/// \brief IR node that builds a ProjectOp over the given columns (column selection).
class ProjectNode : public Dataset {
public:
/// \brief Constructor
/// \param[in] child Input dataset this node operates on
/// \param[in] columns Names of the columns to keep; must be non-empty (enforced by ValidateParams)
explicit ProjectNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &columns);
/// \brief Destructor
~ProjectNode() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
private:
// Column names to project; validated to be non-empty and well formed.
std::vector<std::string> columns_;
};
} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_PROJECT_NODE_H_

View File

@ -0,0 +1,59 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/datasetops/rename_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace api {
// Constructor: capture the input/output column name lists and attach the child dataset.
RenameNode::RenameNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &input_columns,
                       const std::vector<std::string> &output_columns)
    : input_columns_(input_columns), output_columns_(output_columns) {
  children.push_back(std::move(child));
}
// Validate the rename mapping: column lists must be the same length and contain valid names.
Status RenameNode::ValidateParams() {
  if (input_columns_.size() == output_columns_.size()) {
    RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RenameNode", "input_columns", input_columns_));
    RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RenameNode", "output_columns", output_columns_));
    return Status::OK();
  }
  std::string err_msg = "RenameNode: input and output columns must be the same size";
  MS_LOG(ERROR) << err_msg;
  RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
std::vector<std::shared_ptr<DatasetOp>> RenameNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<RenameOp>(input_columns_, output_columns_, connector_que_size_));
return node_ops;
}
} // namespace api
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,56 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RENAME_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RENAME_NODE_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/include/datasets.h"
namespace mindspore {
namespace dataset {
namespace api {
/// \brief IR node that builds a RenameOp mapping input_columns to output_columns, element-wise.
class RenameNode : public Dataset {
public:
/// \brief Constructor
/// \param[in] child Input dataset this node operates on
/// \param[in] input_columns Existing column names to rename
/// \param[in] output_columns New column names; must be the same size as input_columns (enforced by ValidateParams)
explicit RenameNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &input_columns,
const std::vector<std::string> &output_columns);
/// \brief Destructor
~RenameNode() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
private:
// Parallel lists: input_columns_[i] is renamed to output_columns_[i].
std::vector<std::string> input_columns_;
std::vector<std::string> output_columns_;
};
} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RENAME_NODE_H_

View File

@ -0,0 +1,54 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/datasetops/repeat_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace api {
// Constructor: store the repeat count and attach the child dataset.
RepeatNode::RepeatNode(std::shared_ptr<Dataset> child, int32_t count) : repeat_count_(count) {
  children.push_back(std::move(child));
}
std::vector<std::shared_ptr<DatasetOp>> RepeatNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<RepeatOp>(repeat_count_));
return node_ops;
}
// Validate repeat_count: -1 or any positive integer is accepted; everything else is rejected.
Status RepeatNode::ValidateParams() {
  if (repeat_count_ > 0 || repeat_count_ == -1) {
    return Status::OK();
  }
  std::string err_msg = "RepeatNode: repeat_count should be either -1 or positive integer, repeat_count_: " +
                        std::to_string(repeat_count_);
  MS_LOG(ERROR) << err_msg;
  RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
} // namespace api
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,56 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_REPEAT_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_REPEAT_NODE_H_
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "minddata/dataset/include/datasets.h"
namespace mindspore {
namespace dataset {
namespace api {
/// \brief IR node that builds a RepeatOp replaying the child dataset repeat_count_ times.
class RepeatNode : public Dataset {
public:
/// \brief Constructor
/// \param[in] child Input dataset this node operates on
/// \param[in] count Number of repeats; -1 or any positive integer (enforced by ValidateParams)
explicit RepeatNode(std::shared_ptr<Dataset> child, int32_t count);
/// \brief Destructor
~RepeatNode() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
private:
// Repeat count; -1 is accepted by validation (semantics defined by RepeatOp).
int32_t repeat_count_;
};
} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_REPEAT_NODE_H_

View File

@ -0,0 +1,59 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/datasetops/shuffle_op.h"
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace api {
// Constructor: record the shuffle parameters, draw the shuffle seed from GetSeed(),
// and attach the child dataset.
ShuffleNode::ShuffleNode(std::shared_ptr<Dataset> child, int32_t shuffle_size, bool reset_every_epoch)
    : shuffle_size_(shuffle_size), shuffle_seed_(GetSeed()), reset_every_epoch_(reset_every_epoch) {
  children.push_back(std::move(child));
}
// Function to build the ShuffleOp
std::vector<std::shared_ptr<DatasetOp>> ShuffleNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<ShuffleOp>(shuffle_size_, shuffle_seed_, connector_que_size_, reset_every_epoch_,
rows_per_buffer_));
return node_ops;
}
// Validate that the shuffle buffer size is at least 2 (a buffer of 0 or 1 rows cannot shuffle).
Status ShuffleNode::ValidateParams() {
  if (shuffle_size_ > 1) {
    return Status::OK();
  }
  std::string err_msg = "ShuffleNode: Invalid input, shuffle_size: " + std::to_string(shuffle_size_);
  MS_LOG(ERROR) << err_msg;
  RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
} // namespace api
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,52 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SHUFFLE_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SHUFFLE_NODE_H_
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "minddata/dataset/include/datasets.h"
namespace mindspore {
namespace dataset {
namespace api {
/// \brief IR node that builds a ShuffleOp with a row buffer of size shuffle_size.
class ShuffleNode : public Dataset {
public:
/// \brief Constructor
/// \param[in] child Input dataset this node operates on
/// \param[in] shuffle_size Shuffle buffer size; must be greater than 1 (enforced by ValidateParams)
/// \param[in] reset_every_epoch Flag passed through to ShuffleOp; by its name, controls per-epoch reshuffle
ShuffleNode(std::shared_ptr<Dataset> child, int32_t shuffle_size, bool reset_every_epoch);
/// \brief Destructor
~ShuffleNode() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
private:
int32_t shuffle_size_;   // number of rows held in the shuffle buffer
uint32_t shuffle_seed_;  // RNG seed, drawn from GetSeed() at construction time
bool reset_every_epoch_;
};
} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SHUFFLE_NODE_H_

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -0,0 +1,55 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/take_node.h"
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/datasetops/take_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace api {
// Constructor: store the take count and attach the child dataset.
TakeNode::TakeNode(std::shared_ptr<Dataset> child, int32_t count) : take_count_(count) {
  children.push_back(std::move(child));
}
// Function to build the TakeOp
std::vector<std::shared_ptr<DatasetOp>> TakeNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<TakeOp>(take_count_, connector_que_size_));
return node_ops;
}
// Validate take_count: -1 or any positive integer is accepted; everything else is rejected.
Status TakeNode::ValidateParams() {
  if (take_count_ > 0 || take_count_ == -1) {
    return Status::OK();
  }
  std::string err_msg =
    "TakeNode: take_count should be either -1 or positive integer, take_count: " + std::to_string(take_count_);
  MS_LOG(ERROR) << err_msg;
  RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
} // namespace api
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,54 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TAKE_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TAKE_NODE_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/include/datasets.h"
namespace mindspore {
namespace dataset {
namespace api {
/// \brief IR node that builds a TakeOp limiting the child dataset to take_count_ rows.
class TakeNode : public Dataset {
public:
/// \brief Constructor
/// \param[in] child Input dataset this node operates on
/// \param[in] count Number of rows to take; -1 or any positive integer (enforced by ValidateParams)
explicit TakeNode(std::shared_ptr<Dataset> child, int32_t count);
/// \brief Destructor
~TakeNode() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
private:
// Take count; -1 is accepted by validation (semantics defined by TakeOp).
int32_t take_count_;
};
} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TAKE_NODE_H_

View File

@ -1201,85 +1201,6 @@ class VOCNode : public Dataset {
// DERIVED DATASET CLASSES FOR DATASET OPS
// (In alphabetical order)
#ifndef ENABLE_ANDROID
/// \brief IR node for the BucketBatchByLength operation.
class BucketBatchByLengthNode : public Dataset {
public:
/// \brief Constructor
/// \param[in] child Input dataset this node operates on
/// \param[in] column_names Columns whose element lengths drive bucketing
/// \param[in] bucket_boundaries Boundaries separating the buckets
/// \param[in] bucket_batch_sizes Batch size per bucket
/// \param[in] element_length_function Optional function computing an element's length from a TensorRow
/// \param[in] pad_info Per-column pad shape and pad value
/// \param[in] pad_to_bucket_boundary Whether to pad up to the bucket boundary
/// \param[in] drop_remainder Whether to drop incomplete final batches
BucketBatchByLengthNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &column_names,
const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
std::function<TensorRow(TensorRow)> element_length_function = nullptr,
const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info = {},
bool pad_to_bucket_boundary = false, bool drop_remainder = false);
/// \brief Destructor
~BucketBatchByLengthNode() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
private:
std::vector<std::string> column_names_;
std::vector<int32_t> bucket_boundaries_;
std::vector<int32_t> bucket_batch_sizes_;
std::function<TensorRow(TensorRow)> element_length_function_;
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_info_;
bool pad_to_bucket_boundary_;
bool drop_remainder_;
};
/// \brief IR node for the BuildVocab operation: populates a Vocab from dataset columns.
class BuildVocabNode : public Dataset {
public:
/// \brief Constructor
/// \param[in] child Input dataset this node operates on
/// \param[in] vocab Vocab object to be populated
/// \param[in] columns Columns from which vocabulary words are drawn
/// \param[in] freq_range (min, max) frequency range of words to include
/// \param[in] top_k Number of most-frequent words to keep
/// \param[in] special_tokens Special tokens to add to the vocab
/// \param[in] special_first Whether special tokens are prepended (presumably; confirm with BuildVocabOp)
BuildVocabNode(std::shared_ptr<Dataset> child, std::shared_ptr<Vocab> vocab, const std::vector<std::string> &columns,
const std::pair<int64_t, int64_t> &freq_range, int64_t top_k,
const std::vector<std::string> &special_tokens, bool special_first);
/// \brief Destructor
~BuildVocabNode() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
private:
std::shared_ptr<Vocab> vocab_;
std::vector<std::string> columns_;
std::pair<int64_t, int64_t> freq_range_;
int64_t top_k_;
std::vector<std::string> special_tokens_;
bool special_first_;
};
#endif
/// \brief IR node that concatenates several datasets end to end.
class ConcatNode : public Dataset {
public:
/// \brief Constructor
/// \param[in] datasets The datasets to be concatenated
explicit ConcatNode(const std::vector<std::shared_ptr<Dataset>> &datasets);
/// \brief Destructor
~ConcatNode() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
private:
// Datasets to be concatenated.
std::vector<std::shared_ptr<Dataset>> datasets_;
};
class MapNode : public Dataset {
public:
/// \brief Constructor
@ -1305,84 +1226,6 @@ class MapNode : public Dataset {
std::vector<std::string> project_columns_;
};
/// \brief IR node that builds a ProjectOp over the given columns (column selection).
class ProjectNode : public Dataset {
public:
/// \brief Constructor
/// \param[in] child Input dataset this node operates on
/// \param[in] columns Names of the columns to keep; must be non-empty (enforced by ValidateParams)
explicit ProjectNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &columns);
/// \brief Destructor
~ProjectNode() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
private:
// Column names to project; validated to be non-empty and well formed.
std::vector<std::string> columns_;
};
/// \brief IR node that builds a RenameOp mapping input_columns to output_columns, element-wise.
class RenameNode : public Dataset {
public:
/// \brief Constructor
/// \param[in] child Input dataset this node operates on
/// \param[in] input_columns Existing column names to rename
/// \param[in] output_columns New column names; must be the same size as input_columns (enforced by ValidateParams)
explicit RenameNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &input_columns,
const std::vector<std::string> &output_columns);
/// \brief Destructor
~RenameNode() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
private:
// Parallel lists: input_columns_[i] is renamed to output_columns_[i].
std::vector<std::string> input_columns_;
std::vector<std::string> output_columns_;
};
/// \brief IR node that builds a RepeatOp replaying the child dataset repeat_count_ times.
class RepeatNode : public Dataset {
public:
/// \brief Constructor
/// \param[in] child Input dataset this node operates on
/// \param[in] count Number of repeats; -1 or any positive integer (enforced by ValidateParams)
explicit RepeatNode(std::shared_ptr<Dataset> child, int32_t count);
/// \brief Destructor
~RepeatNode() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
private:
// Repeat count; -1 is accepted by validation (semantics defined by RepeatOp).
int32_t repeat_count_;
};
/// \brief IR node that builds a ShuffleOp with a row buffer of size shuffle_size.
class ShuffleNode : public Dataset {
public:
/// \brief Constructor
/// \param[in] child Input dataset this node operates on
/// \param[in] shuffle_size Shuffle buffer size; must be greater than 1 (enforced by ValidateParams)
/// \param[in] reset_every_epoch Flag passed through to ShuffleOp; by its name, controls per-epoch reshuffle
ShuffleNode(std::shared_ptr<Dataset> child, int32_t shuffle_size, bool reset_every_epoch);
/// \brief Destructor
~ShuffleNode() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
private:
int32_t shuffle_size_;   // number of rows held in the shuffle buffer
uint32_t shuffle_seed_;  // RNG seed drawn at construction time
bool reset_every_epoch_;
};
class SkipNode : public Dataset {
public:
/// \brief Constructor
@ -1403,26 +1246,6 @@ class SkipNode : public Dataset {
int32_t skip_count_;
};
/// \brief IR node that builds a TakeOp limiting the child dataset to take_count_ rows.
class TakeNode : public Dataset {
public:
/// \brief Constructor
/// \param[in] child Input dataset this node operates on
/// \param[in] count Number of rows to take; -1 or any positive integer (enforced by ValidateParams)
explicit TakeNode(std::shared_ptr<Dataset> child, int32_t count);
/// \brief Destructor
~TakeNode() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
private:
// Take count; -1 is accepted by validation (semantics defined by TakeOp).
int32_t take_count_;
};
class ZipNode : public Dataset {
public:
/// \brief Constructor

View File

@ -18,6 +18,13 @@
#include "minddata/dataset/include/config.h"
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
using namespace mindspore::dataset::api;
using mindspore::dataset::ShuffleMode;
using mindspore::dataset::Tensor;

View File

@ -17,6 +17,11 @@
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
using namespace mindspore::dataset::api;
using mindspore::dataset::Tensor;

View File

@ -18,8 +18,17 @@
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/include/vision.h"
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
#include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
#include "minddata/dataset/engine/ir/datasetops/take_node.h"
using namespace mindspore::dataset::api;
using mindspore::dataset::Tensor;

View File

@ -16,10 +16,14 @@
#include "common/common.h"
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/include/vision.h"
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
using namespace mindspore::dataset;
using namespace mindspore::dataset::api;

View File

@ -16,8 +16,14 @@
#include "common/common.h"
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
using namespace mindspore::dataset::api;
using mindspore::dataset::Tensor;
using mindspore::dataset::TensorShape;

View File

@ -16,8 +16,13 @@
#include "common/common.h"
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
using namespace mindspore::dataset::api;
using mindspore::dataset::Tensor;

View File

@ -19,6 +19,11 @@
#include "minddata/dataset/include/vision.h"
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
using namespace mindspore::dataset::api;
using mindspore::dataset::BorderType;

View File

@ -18,8 +18,14 @@
#include "minddata/dataset/include/transforms.h"
#include "minddata/dataset/include/vision.h"
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
using namespace mindspore::dataset::api;
using mindspore::dataset::BorderType;
using mindspore::dataset::Tensor;