!8802 Boilerplate code for IR Tree optimizer

From: @nsyca
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2020-11-27 06:39:03 +08:00 committed by Gitee
commit bd8522aff7
86 changed files with 1942 additions and 622 deletions

View File

@ -568,8 +568,8 @@ std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocab(
const std::vector<std::string> &col_names, uint32_t vocab_size, float character_coverage,
SentencePieceModel model_type, const std::unordered_map<std::string, std::string> &params) {
auto vocab = std::make_shared<SentencePieceVocab>();
auto ds = std::make_shared<BuildSentenceVocabNode>(IRNode(), vocab, col_names, vocab_size, character_coverage,
model_type, params);
auto ds = std::make_shared<BuildSentenceVocabNode>(IRNode()->DeepCopy(), vocab, col_names, vocab_size,
character_coverage, model_type, params);
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
Status rc = runtime_context->Init();
@ -600,8 +600,8 @@ std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &colum
const std::pair<int64_t, int64_t> &freq_range, int64_t top_k,
const std::vector<std::string> &special_tokens, bool special_first) {
auto vocab = std::make_shared<Vocab>();
auto ds =
std::make_shared<BuildVocabNode>(IRNode(), vocab, columns, freq_range, top_k, special_tokens, special_first);
auto ds = std::make_shared<BuildVocabNode>(IRNode()->DeepCopy(), vocab, columns, freq_range, top_k, special_tokens,
special_first);
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
Status rc = runtime_context->Init();

View File

@ -190,13 +190,12 @@ std::shared_ptr<SamplerRT> PKSamplerObj::Build() {
return sampler;
}
#ifndef ENABLE_ANDROID
// PreBuiltOperation
PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler)
: sp_(std::move(sampler)), sp_minddataset_(nullptr) {}
PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler) : sp_(std::move(sampler)) {}
#ifndef ENABLE_ANDROID
PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler)
: sp_(nullptr), sp_minddataset_(std::move(sampler)) {}
: sp_minddataset_(std::move(sampler)) {}
#endif
bool PreBuiltSamplerObj::ValidateParams() { return true; }
@ -207,6 +206,13 @@ std::shared_ptr<SamplerRT> PreBuiltSamplerObj::Build() { return sp_; }
std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDataset() { return sp_minddataset_; }
#endif
// Create a new PreBuiltSamplerObj that wraps the same pre-built sampler pointer as this object.
// \return A shared pointer to the new sampler object
std::shared_ptr<SamplerObj> PreBuiltSamplerObj::Copy() {
#ifndef ENABLE_ANDROID
  // When a MindDataset sampler is held, it is the one that gets wrapped.
  const bool has_minddataset_sampler = (sp_minddataset_ != nullptr);
  if (has_minddataset_sampler) {
    return std::make_shared<PreBuiltSamplerObj>(sp_minddataset_);
  }
#endif
  return std::make_shared<PreBuiltSamplerObj>(sp_);
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object

View File

@ -30,8 +30,6 @@
namespace mindspore {
namespace dataset {
TensorOperation::TensorOperation() {}
/* ####################################### Validator Functions ############################################ */
Status ValidateVectorFillvalue(const std::string &transform_name, const std::vector<uint8_t> &fill_value) {
if (fill_value.empty() || (fill_value.size() != 1 && fill_value.size() != 3)) {
@ -231,7 +229,7 @@ std::shared_ptr<TensorOp> PreBuiltOperation::Build() { return op_; }
// RandomApplyOperation
RandomApplyOperation::RandomApplyOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms, double prob)
: transforms_(transforms), prob_(prob) {}
: TensorOperation(true), transforms_(transforms), prob_(prob) {}
Status RandomApplyOperation::ValidateParams() {
RETURN_IF_NOT_OK(ValidateVectorTransforms("RandomApply", transforms_));
@ -248,7 +246,7 @@ std::shared_ptr<TensorOp> RandomApplyOperation::Build() {
// RandomChoiceOperation
RandomChoiceOperation::RandomChoiceOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms)
: transforms_(transforms) {}
: TensorOperation(true), transforms_(transforms) {}
Status RandomChoiceOperation::ValidateParams() {
RETURN_IF_NOT_OK(ValidateVectorTransforms("RandomChoice", transforms_));

View File

@ -734,7 +734,9 @@ RandomAffineOperation::RandomAffineOperation(const std::vector<float_t> &degrees
scale_range_(scale_range),
shear_ranges_(shear_ranges),
interpolation_(interpolation),
fill_value_(fill_value) {}
fill_value_(fill_value) {
random_op_ = true;
}
Status RandomAffineOperation::ValidateParams() {
// Degrees
@ -867,7 +869,7 @@ std::shared_ptr<TensorOp> RandomAffineOperation::Build() {
}
// RandomColorOperation.
RandomColorOperation::RandomColorOperation(float t_lb, float t_ub) : t_lb_(t_lb), t_ub_(t_ub) {}
RandomColorOperation::RandomColorOperation(float t_lb, float t_ub) : t_lb_(t_lb), t_ub_(t_ub) { random_op_ = true; }
Status RandomColorOperation::ValidateParams() {
// Do some input validation.
@ -891,7 +893,9 @@ Status RandomColorOperation::ValidateParams() {
// RandomColorAdjustOperation.
RandomColorAdjustOperation::RandomColorAdjustOperation(std::vector<float> brightness, std::vector<float> contrast,
std::vector<float> saturation, std::vector<float> hue)
: brightness_(brightness), contrast_(contrast), saturation_(saturation), hue_(hue) {}
: brightness_(brightness), contrast_(contrast), saturation_(saturation), hue_(hue) {
random_op_ = true;
}
Status RandomColorAdjustOperation::ValidateParams() {
// brightness
@ -1012,11 +1016,14 @@ std::shared_ptr<TensorOp> RandomColorAdjustOperation::Build() {
// RandomCropOperation
RandomCropOperation::RandomCropOperation(std::vector<int32_t> size, std::vector<int32_t> padding, bool pad_if_needed,
std::vector<uint8_t> fill_value, BorderType padding_mode)
: size_(size),
: TensorOperation(true),
size_(size),
padding_(padding),
pad_if_needed_(pad_if_needed),
fill_value_(fill_value),
padding_mode_(padding_mode) {}
padding_mode_(padding_mode) {
random_op_ = true;
}
Status RandomCropOperation::ValidateParams() {
// size
@ -1083,7 +1090,12 @@ std::shared_ptr<TensorOp> RandomCropOperation::Build() {
RandomCropDecodeResizeOperation::RandomCropDecodeResizeOperation(std::vector<int32_t> size, std::vector<float> scale,
std::vector<float> ratio,
InterpolationMode interpolation, int32_t max_attempts)
: size_(size), scale_(scale), ratio_(ratio), interpolation_(interpolation), max_attempts_(max_attempts) {}
: TensorOperation(true),
size_(size),
scale_(scale),
ratio_(ratio),
interpolation_(interpolation),
max_attempts_(max_attempts) {}
Status RandomCropDecodeResizeOperation::ValidateParams() {
// size
@ -1176,7 +1188,8 @@ std::shared_ptr<TensorOp> RandomCropDecodeResizeOperation::Build() {
RandomCropWithBBoxOperation::RandomCropWithBBoxOperation(std::vector<int32_t> size, std::vector<int32_t> padding,
bool pad_if_needed, std::vector<uint8_t> fill_value,
BorderType padding_mode)
: size_(size),
: TensorOperation(true),
size_(size),
padding_(padding),
pad_if_needed_(pad_if_needed),
fill_value_(fill_value),
@ -1245,7 +1258,8 @@ std::shared_ptr<TensorOp> RandomCropWithBBoxOperation::Build() {
}
// RandomHorizontalFlipOperation
RandomHorizontalFlipOperation::RandomHorizontalFlipOperation(float probability) : probability_(probability) {}
RandomHorizontalFlipOperation::RandomHorizontalFlipOperation(float probability)
: TensorOperation(true), probability_(probability) {}
Status RandomHorizontalFlipOperation::ValidateParams() {
RETURN_IF_NOT_OK(ValidateProbability("RandomHorizontalFlip", probability_));
@ -1260,7 +1274,7 @@ std::shared_ptr<TensorOp> RandomHorizontalFlipOperation::Build() {
// RandomHorizontalFlipWithBBoxOperation
RandomHorizontalFlipWithBBoxOperation::RandomHorizontalFlipWithBBoxOperation(float probability)
: probability_(probability) {}
: TensorOperation(true), probability_(probability) {}
Status RandomHorizontalFlipWithBBoxOperation::ValidateParams() {
RETURN_IF_NOT_OK(ValidateProbability("RandomHorizontalFlipWithBBox", probability_));
@ -1275,7 +1289,8 @@ std::shared_ptr<TensorOp> RandomHorizontalFlipWithBBoxOperation::Build() {
}
// RandomPosterizeOperation
RandomPosterizeOperation::RandomPosterizeOperation(const std::vector<uint8_t> &bit_range) : bit_range_(bit_range) {}
RandomPosterizeOperation::RandomPosterizeOperation(const std::vector<uint8_t> &bit_range)
: TensorOperation(true), bit_range_(bit_range) {}
Status RandomPosterizeOperation::ValidateParams() {
if (bit_range_.size() != 2) {
@ -1309,7 +1324,7 @@ std::shared_ptr<TensorOp> RandomPosterizeOperation::Build() {
}
// RandomResizeOperation
RandomResizeOperation::RandomResizeOperation(std::vector<int32_t> size) : size_(size) {}
RandomResizeOperation::RandomResizeOperation(std::vector<int32_t> size) : TensorOperation(true), size_(size) {}
Status RandomResizeOperation::ValidateParams() {
// size
@ -1343,7 +1358,8 @@ std::shared_ptr<TensorOp> RandomResizeOperation::Build() {
}
// RandomResizeWithBBoxOperation
RandomResizeWithBBoxOperation::RandomResizeWithBBoxOperation(std::vector<int32_t> size) : size_(size) {}
RandomResizeWithBBoxOperation::RandomResizeWithBBoxOperation(std::vector<int32_t> size)
: TensorOperation(true), size_(size) {}
Status RandomResizeWithBBoxOperation::ValidateParams() {
// size
@ -1380,7 +1396,12 @@ std::shared_ptr<TensorOp> RandomResizeWithBBoxOperation::Build() {
RandomResizedCropOperation::RandomResizedCropOperation(std::vector<int32_t> size, std::vector<float> scale,
std::vector<float> ratio, InterpolationMode interpolation,
int32_t max_attempts)
: size_(size), scale_(scale), ratio_(ratio), interpolation_(interpolation), max_attempts_(max_attempts) {}
: TensorOperation(true),
size_(size),
scale_(scale),
ratio_(ratio),
interpolation_(interpolation),
max_attempts_(max_attempts) {}
Status RandomResizedCropOperation::ValidateParams() {
// size
@ -1536,7 +1557,8 @@ std::shared_ptr<TensorOp> RandomResizedCropWithBBoxOperation::Build() {
RandomRotationOperation::RandomRotationOperation(std::vector<float> degrees, InterpolationMode interpolation_mode,
bool expand, std::vector<float> center,
std::vector<uint8_t> fill_value)
: degrees_(degrees),
: TensorOperation(true),
degrees_(degrees),
interpolation_mode_(interpolation_mode),
expand_(expand),
center_(center),
@ -1603,7 +1625,7 @@ std::shared_ptr<TensorOp> RandomRotationOperation::Build() {
// RandomSelectSubpolicyOperation.
RandomSelectSubpolicyOperation::RandomSelectSubpolicyOperation(
std::vector<std::vector<std::pair<std::shared_ptr<TensorOperation>, double>>> policy)
: policy_(policy) {}
: TensorOperation(true), policy_(policy) {}
Status RandomSelectSubpolicyOperation::ValidateParams() {
if (policy_.empty()) {
@ -1650,7 +1672,8 @@ std::shared_ptr<TensorOp> RandomSelectSubpolicyOperation::Build() {
}
// Function to create RandomSharpness.
RandomSharpnessOperation::RandomSharpnessOperation(std::vector<float> degrees) : degrees_(degrees) {}
RandomSharpnessOperation::RandomSharpnessOperation(std::vector<float> degrees)
: TensorOperation(true), degrees_(degrees) {}
Status RandomSharpnessOperation::ValidateParams() {
if (degrees_.size() != 2 || degrees_[0] < 0 || degrees_[1] < 0) {
@ -1674,7 +1697,8 @@ std::shared_ptr<TensorOp> RandomSharpnessOperation::Build() {
}
// RandomSolarizeOperation.
RandomSolarizeOperation::RandomSolarizeOperation(std::vector<uint8_t> threshold) : threshold_(threshold) {}
RandomSolarizeOperation::RandomSolarizeOperation(std::vector<uint8_t> threshold)
: TensorOperation(true), threshold_(threshold) {}
Status RandomSolarizeOperation::ValidateParams() {
if (threshold_.size() != 2) {
@ -1705,7 +1729,8 @@ std::shared_ptr<TensorOp> RandomSolarizeOperation::Build() {
}
// RandomVerticalFlipOperation
RandomVerticalFlipOperation::RandomVerticalFlipOperation(float probability) : probability_(probability) {}
RandomVerticalFlipOperation::RandomVerticalFlipOperation(float probability)
: TensorOperation(true), probability_(probability) {}
Status RandomVerticalFlipOperation::ValidateParams() {
RETURN_IF_NOT_OK(ValidateProbability("RandomVerticalFlip", probability_));
@ -1720,7 +1745,7 @@ std::shared_ptr<TensorOp> RandomVerticalFlipOperation::Build() {
// RandomVerticalFlipWithBBoxOperation
RandomVerticalFlipWithBBoxOperation::RandomVerticalFlipWithBBoxOperation(float probability)
: probability_(probability) {}
: TensorOperation(true), probability_(probability) {}
Status RandomVerticalFlipWithBBoxOperation::ValidateParams() {
RETURN_IF_NOT_OK(ValidateProbability("RandomVerticalFlipWithBBox", probability_));

View File

@ -9,11 +9,13 @@ set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES
build_sentence_piece_vocab_node.cc
build_vocab_node.cc
concat_node.cc
epoch_ctrl_node.cc
filter_node.cc
map_node.cc
project_node.cc
rename_node.cc
repeat_node.cc
root_node.cc
shuffle_node.cc
skip_node.cc
sync_wait_node.cc

View File

@ -43,14 +43,29 @@ BatchNode::BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, boo
batch_size_func_(batch_size_func),
batch_map_func_(batch_map_func),
pad_map_(pad_map) {
this->children.push_back(child);
this->AddChild(child);
}
#endif
// constructor #2, called by C++ API
BatchNode::BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, bool drop_remainder)
: batch_size_(batch_size), drop_remainder_(drop_remainder), pad_(false) {
this->children.push_back(child);
this->AddChild(child);
}
// Create a new BatchNode with the same parameters as this node.
// nullptr is passed in the child slot, so the new node is constructed without a child.
// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> BatchNode::Copy() {
#ifdef ENABLE_PYTHON
  // Python build: forward every batch parameter, including the Python callbacks and pad map.
  return std::make_shared<BatchNode>(nullptr, batch_size_, drop_remainder_, pad_, in_col_names_, out_col_names_,
                                     col_order_, batch_size_func_, batch_map_func_, pad_map_);
#else
  // C++-only build: use the constructor that takes just batch size and drop-remainder flag.
  return std::make_shared<BatchNode>(nullptr, batch_size_, drop_remainder_);
#endif
}
// Write a one-line description of this node: its name, batch size and drop_remainder flag.
// \param out - The output stream to write output to
void BatchNode::Print(std::ostream &out) const {
  std::string description = Name();
  description += "(batch_size:";
  description += std::to_string(batch_size_);
  description += " drop_remainder:";
  if (drop_remainder_) {
    description += "true";
  } else {
    description += "false";
  }
  description += ")";
  // Emit the whole description with a single stream insertion, as a single string.
  out << description;
}
Status BatchNode::ValidateParams() {

View File

@ -44,6 +44,18 @@ class BatchNode : public DatasetNode {
/// \brief Destructor
~BatchNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kBatchNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

View File

@ -41,7 +41,17 @@ BucketBatchByLengthNode::BucketBatchByLengthNode(
pad_info_(pad_info),
pad_to_bucket_boundary_(pad_to_bucket_boundary),
drop_remainder_(drop_remainder) {
this->children.push_back(child);
this->AddChild(child);
}
std::shared_ptr<DatasetNode> BucketBatchByLengthNode::Copy() {
auto node = std::make_shared<BucketBatchByLengthNode>(nullptr, column_names_, bucket_boundaries_, bucket_batch_sizes_,
element_length_function_, pad_info_, pad_to_bucket_boundary_);
return node;
}
// Write a one-line description of this node: its name and column names; other parameters are elided.
// \param out - The output stream to write output to
void BucketBatchByLengthNode::Print(std::ostream &out) const {
  std::string description = Name();
  description += "(columns:";
  description += PrintColumns(column_names_);
  description += ",...)";
  out << description;
}
std::vector<std::shared_ptr<DatasetOp>> BucketBatchByLengthNode::Build() {

View File

@ -40,6 +40,18 @@ class BucketBatchByLengthNode : public DatasetNode {
/// \brief Destructor
~BucketBatchByLengthNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kBucketBatchByLengthNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

View File

@ -22,6 +22,7 @@
#include <vector>
#include "minddata/dataset/engine/datasetops/build_sentence_piece_vocab_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
@ -38,7 +39,18 @@ BuildSentenceVocabNode::BuildSentenceVocabNode(std::shared_ptr<DatasetNode> chil
character_coverage_(character_coverage),
model_type_(model_type),
params_(params) {
this->children.push_back(child);
this->AddChild(child);
}
// Create a new BuildSentenceVocabNode with the same parameters as this node.
// nullptr is passed in the child slot; the same vocab_ pointer is handed to the new node.
// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> BuildSentenceVocabNode::Copy() {
  return std::make_shared<BuildSentenceVocabNode>(nullptr, vocab_, col_names_, vocab_size_, character_coverage_,
                                                  model_type_, params_);
}
// Write a one-line description of this node: its name, a vocab placeholder, the column names and
// the vocab size; remaining parameters are elided.
// \param out - The output stream to write output to
void BuildSentenceVocabNode::Print(std::ostream &out) const {
  // The opening parenthesis after the name balances the closing one at the end of the description.
  out << Name() + "(<vocab>," + "columns:" + PrintColumns(col_names_) + ",vocab_size:" + std::to_string(vocab_size_) +
           ",...)";
}
// Function to build BuildSentenceVocabNode
@ -81,5 +93,16 @@ Status BuildSentenceVocabNode::ValidateParams() {
return Status::OK();
}
// Accept a NodePass visitor: pass this node, typed as BuildSentenceVocabNode, to NodePass::Visit.
Status BuildSentenceVocabNode::Accept(NodePass *p, bool *modified) {
  // Obtain a shared pointer to this node with the derived type so the matching Visit overload is selected.
  auto self = shared_from_base<BuildSentenceVocabNode>();
  return p->Visit(self, modified);
}
// Accept a NodePass visitor: pass this node, typed as BuildSentenceVocabNode, to NodePass::VisitAfter.
Status BuildSentenceVocabNode::AcceptAfter(NodePass *p, bool *modified) {
  // Obtain a shared pointer to this node with the derived type so the matching VisitAfter overload is selected.
  auto self = shared_from_base<BuildSentenceVocabNode>();
  return p->VisitAfter(self, modified);
}
} // namespace dataset
} // namespace mindspore

View File

@ -38,6 +38,18 @@ class BuildSentenceVocabNode : public DatasetNode {
/// \brief Destructor
~BuildSentenceVocabNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kBuildSentencePieceVocabNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
@ -46,6 +58,18 @@ class BuildSentenceVocabNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override;
private:
std::shared_ptr<SentencePieceVocab> vocab_;
std::vector<std::string> col_names_;

View File

@ -22,7 +22,7 @@
#include <vector>
#include "minddata/dataset/engine/datasetops/build_vocab_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
@ -36,7 +36,17 @@ BuildVocabNode::BuildVocabNode(std::shared_ptr<DatasetNode> child, std::shared_p
top_k_(top_k),
special_tokens_(special_tokens),
special_first_(special_first) {
this->children.push_back(child);
this->AddChild(child);
}
// Create a new BuildVocabNode with the same parameters as this node.
// nullptr is passed in the child slot; the same vocab_ pointer is handed to the new node.
// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> BuildVocabNode::Copy() {
  return std::make_shared<BuildVocabNode>(nullptr, vocab_, columns_, freq_range_, top_k_, special_tokens_,
                                          special_first_);
}
// Write a one-line description of this node: its name, a vocab placeholder and the column names;
// remaining parameters are elided.
// \param out - The output stream to write output to
void BuildVocabNode::Print(std::ostream &out) const {
  std::string description = Name();
  description += "(<vocab>,";
  description += "columns:";
  description += PrintColumns(columns_);
  description += ",...)";
  out << description;
}
// Function to build BuildVocabNode
@ -78,5 +88,16 @@ Status BuildVocabNode::ValidateParams() {
return Status::OK();
}
// Accept a NodePass visitor: pass this node, typed as BuildVocabNode, to NodePass::Visit.
Status BuildVocabNode::Accept(NodePass *p, bool *modified) {
  // Obtain a shared pointer to this node with the derived type so the matching Visit overload is selected.
  auto self = shared_from_base<BuildVocabNode>();
  return p->Visit(self, modified);
}
// Accept a NodePass visitor: pass this node, typed as BuildVocabNode, to NodePass::VisitAfter.
Status BuildVocabNode::AcceptAfter(NodePass *p, bool *modified) {
  // Obtain a shared pointer to this node with the derived type so the matching VisitAfter overload is selected.
  auto self = shared_from_base<BuildVocabNode>();
  return p->VisitAfter(self, modified);
}
} // namespace dataset
} // namespace mindspore

View File

@ -37,6 +37,18 @@ class BuildVocabNode : public DatasetNode {
/// \brief Destructor
~BuildVocabNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kBuildVocabNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
@ -45,6 +57,18 @@ class BuildVocabNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override;
private:
std::shared_ptr<Vocab> vocab_;
std::vector<std::string> columns_;

View File

@ -22,7 +22,7 @@
#include <vector>
#include "minddata/dataset/engine/datasetops/concat_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
@ -35,17 +35,25 @@ ConcatNode::ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets
: sampler_(sampler),
children_flag_and_nums_(children_flag_and_nums),
children_start_end_index_(children_start_end_index) {
this->children = datasets;
for (auto const &child : datasets) AddChild(child);
}
std::shared_ptr<DatasetNode> ConcatNode::Copy() {
// create an empty vector to copy a concat
auto node = std::make_shared<ConcatNode>(std::vector<std::shared_ptr<DatasetNode>>());
return node;
}
// Write this node's description, which consists of the node name only.
// \param out - The output stream to write output to
void ConcatNode::Print(std::ostream &out) const {
  const std::string description = Name();
  out << description;
}
Status ConcatNode::ValidateParams() {
if (children.size() < 2) {
if (children_.size() < 2) {
std::string err_msg = "ConcatNode: concatenated datasets are not specified.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (find(children.begin(), children.end(), nullptr) != children.end()) {
if (find(children_.begin(), children_.end(), nullptr) != children_.end()) {
std::string err_msg = "ConcatNode: concatenated datasets should not be null.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
@ -73,5 +81,16 @@ std::vector<std::shared_ptr<DatasetOp>> ConcatNode::Build() {
return node_ops;
}
// Accept a NodePass visitor: pass this node, typed as ConcatNode, to NodePass::Visit.
Status ConcatNode::Accept(NodePass *p, bool *modified) {
  // Obtain a shared pointer to this node with the derived type so the matching Visit overload is selected.
  auto self = shared_from_base<ConcatNode>();
  return p->Visit(self, modified);
}
// Accept a NodePass visitor: pass this node, typed as ConcatNode, to NodePass::VisitAfter.
Status ConcatNode::AcceptAfter(NodePass *p, bool *modified) {
  // Obtain a shared pointer to this node with the derived type so the matching VisitAfter overload is selected.
  auto self = shared_from_base<ConcatNode>();
  return p->VisitAfter(self, modified);
}
} // namespace dataset
} // namespace mindspore

View File

@ -38,6 +38,18 @@ class ConcatNode : public DatasetNode {
/// \brief Destructor
~ConcatNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kConcatNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
@ -50,6 +62,18 @@ class ConcatNode : public DatasetNode {
std::shared_ptr<SamplerObj> sampler_;
std::vector<std::pair<int, int>> children_flag_and_nums_;
std::vector<std::pair<int, int>> children_start_end_index_;
/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override;
};
} // namespace dataset

View File

@ -233,14 +233,92 @@ std::shared_ptr<DatasetNode> DatasetNode::SetNumWorkers(int32_t num_workers) {
return shared_from_this();
}
DatasetNode::DatasetNode() {
DatasetNode::DatasetNode() : cache_(nullptr), parent_(nullptr), children_({}) {
// Fetch some default value from config manager
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
num_workers_ = cfg->num_parallel_workers();
rows_per_buffer_ = cfg->rows_per_buffer();
connector_que_size_ = cfg->op_connector_size();
worker_connector_size_ = cfg->worker_connector_size();
build_status = Status::OK(); // remove me after changing return val of Build()
}
// this function will preform a deep copy of current node (and its descendants), the parent* pointer will not be copied
std::shared_ptr<DatasetNode> DatasetNode::DeepCopy() {
std::shared_ptr<DatasetNode> new_node = this->Copy();
for (const auto &child : children_) {
new_node->AddChild(child->DeepCopy());
}
return new_node;
}
// Format a list of column names for printing.
// \param columns - The column names to format
// \return "<nil>" when the list is empty; otherwise the names joined by ", " and wrapped in
//     square brackets, e.g. "[a, b, c]"
std::string DatasetNode::PrintColumns(const std::vector<std::string> &columns) const {
  if (columns.empty()) {
    return "<nil>";
  }
  std::string me = "[";
  // Index with size_t so the loop bound comparison stays unsigned-to-unsigned.
  for (size_t i = 0; i < columns.size(); ++i) {
    if (i != 0) {
      me += ", ";  // separator goes between elements only
    }
    me += columns[i];
  }
  me += "]";
  return me;
}
// Print the tree rooted at this node to the output stream, starting at nesting level 0.
// \param out - The output stream to write output to
void DatasetNode::PrintTree(std::ostream &out) const {
  int depth{0};  // nesting level counter handed down to PrintNode
  PrintNode(out, &depth);
}
// Print this node and, recursively, its children.
// Each node is written as a prefix followed by its own Print() output; each child starts on a new
// line preceded by one indent unit per nesting level.
// \param out - The output stream to write output to
// \param level - In/out nesting level; incremented around each child and restored afterwards
void DatasetNode::PrintNode(std::ostream &out, int *level) const {
  const std::string prefix = "+-";
  const std::string indent = " ";
  out << prefix;
  Print(out);
  for (const auto &child : this->Children()) {
    out << '\n';
    *level += 1;
    // Emit one indent unit per nesting level before the child's own line.
    int remaining = *level;
    while (remaining > 0) {
      out << indent;
      --remaining;
    }
    child->PrintNode(out, level);
    *level -= 1;
  }
}
// Add a node as a child, node's parent needs to be nullptr
// this function will allow child to be a nullptr, in which case it will simply skip
void DatasetNode::AddChild(std::shared_ptr<DatasetNode> child) {
if (child != nullptr && child->parent_ == nullptr) {
children_.push_back(child);
child->parent_ = this;
} else if (child != nullptr) {
MS_LOG(WARNING) << "DatasetNode::AddChild() Fail" + child->Name() + "'s parent isn't a nullptr.";
}
}
// Remove this node from its parent. If this node has a single child, that child takes this node's
// position in the parent's children list. For now, removal is limited to nodes with at most one child.
// \return Status::OK() on success; an error status if this node has no parent, has more than one
//     child, or (in the single-child case) is not found in its parent's children list
Status DatasetNode::Remove() {
  CHECK_FAIL_RETURN_UNEXPECTED(parent_ != nullptr, "Cannot remove root or a node without parent.");
  CHECK_FAIL_RETURN_UNEXPECTED(children_.size() < 2, "Cannot remove node with more than 1 child.");
  // Hold a reference to this node for the rest of the function: the parent's children list may own
  // the last shared_ptr to it, and erasing/overwriting that entry would otherwise destroy this
  // object while its members are still being updated below.
  const std::shared_ptr<DatasetNode> self = shared_from_this();
  auto &siblings = parent_->children_;
  if (children_.empty()) {  // I am a leaf node, remove me from my parent's children list
    siblings.erase(std::remove(siblings.begin(), siblings.end(), self), siblings.end());  // "erase remove idiom"
  } else {  // replace my position in my parent's children list with my single child
    auto itr = std::find(siblings.begin(), siblings.end(), self);
    CHECK_FAIL_RETURN_UNEXPECTED(itr != siblings.end(), "I am not in my parent's children list.");
    children_[0]->parent_ = parent_;  // set my single child's parent ptr to my parent
    *itr = std::move(children_[0]);   // replace me in my parent's children list with my single child
    children_.clear();                // release my single child from my children list
  }
  parent_ = nullptr;
  return Status::OK();
}
// In DFS tree traversal, each node is visited twice. Accept is called on the first visit.
@ -255,13 +333,25 @@ Status DatasetNode::AcceptAfter(NodePass *p, bool *modified) {
// This method will only be called if its derived class does not implement one.
return p->VisitAfter(shared_from_this(), modified);
}
Status DatasetNode::GetShardId(int32_t *shard_id) {
if (!Children().empty()) {
// Get shard id from the child node
return Children()[0]->GetShardId(shard_id);
} else {
RETURN_STATUS_SYNTAX_ERROR("Get Shard Id failed at source node");
RETURN_STATUS_SYNTAX_ERROR("Get Shard Id failed at source node: " + Name() + "\n");
}
}
// Accept a NodePass visitor: pass this node, typed as SourceNode, to NodePass::Visit.
Status SourceNode::Accept(NodePass *p, bool *modified) {
  // Obtain a shared pointer to this node with the derived type so the matching Visit overload is selected.
  auto self = shared_from_base<SourceNode>();
  return p->Visit(self, modified);
}
// Accept a NodePass visitor: pass this node, typed as SourceNode, to NodePass::VisitAfter.
Status SourceNode::AcceptAfter(NodePass *p, bool *modified) {
  // Obtain a shared pointer to this node with the derived type so the matching VisitAfter overload is selected.
  auto self = shared_from_base<SourceNode>();
  return p->VisitAfter(self, modified);
}
} // namespace dataset
} // namespace mindspore

View File

@ -42,6 +42,45 @@ class NodePass;
} \
} while (false)
// Names for non-leaf IR node
constexpr char kBatchNode[] = "Batch";
constexpr char kBucketBatchByLengthNode[] = "BucketBatchByLength";
constexpr char kBuildSentencePieceVocabNode[] = "BuildSentencePieceVocab";
constexpr char kBuildVocabNode[] = "BuildVocab";
constexpr char kConcatNode[] = "Concat";
constexpr char kDatasetNode[] = "Dataset";
constexpr char kEpochCtrlNode[] = "EpochCtrl";
constexpr char kFilterNode[] = "Filter";
constexpr char kMapNode[] = "Map";
constexpr char kProjectNode[] = "Project";
constexpr char kRenameNode[] = "Rename";
constexpr char kRepeatNode[] = "Repeat";
constexpr char kRootNode[] = "Top";
constexpr char kShuffleNode[] = "Shuffle";
constexpr char kSkipNode[] = "Skip";
constexpr char kSyncWaitNode[] = "SyncWait";
constexpr char kTakeNode[] = "Take";
constexpr char kTransferNode[] = "Transfer";
constexpr char kZipNode[] = "Zip";
// Names for leaf IR node
constexpr char kAlbumNode[] = "AlbumDataset";
constexpr char kCelebANode[] = "CelebADataset";
constexpr char kCifar100Node[] = "Cifar100Dataset";
constexpr char kCifar10Node[] = "Cifar10Dataset";
constexpr char kCLUENode[] = "CLUEDataset";
constexpr char kCocoNode[] = "CocoDataset";
constexpr char kCSVNode[] = "CSVDataset";
constexpr char kGeneratorNode[] = "GeneratorDataset";
constexpr char kImageFolderNode[] = "ImageFolderDataset";
constexpr char kManifestNode[] = "ManifestDataset";
constexpr char kMindDataNode[] = "MindDataDataset";
constexpr char kMnistNode[] = "MnistDataset";
constexpr char kRandomNode[] = "RandomDataset";
constexpr char kTextFileNode[] = "TextFileDataset";
constexpr char kTFRecordNode[] = "TFRecordDataset";
constexpr char kVOCNode[] = "VOCDataset";
Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
int32_t connector_que_size, int32_t rows_per_buffer, std::shared_ptr<DatasetOp> *shuffle_op);
@ -75,6 +114,7 @@ Status ValidateDatasetDirParam(const std::string &dataset_name, std::string data
/// \return Shared pointer to the current Sampler.
std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, int32_t shard_id);
// The base class of all IR nodes
class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
public:
/// \brief Constructor
@ -87,6 +127,36 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
/// \brief Destructor
/// \notes Declared virtual because DatasetNode is a polymorphic base class (it declares pure virtual
///     functions and nodes are held as std::shared_ptr<DatasetNode>), so destroying a node through a
///     base-class pointer must run the derived class's destructor.
virtual ~DatasetNode() = default;
/// \brief Node name getter
/// \return Name of the current node
virtual std::string Name() const = 0;
/// \brief Pure virtual function to print the description
/// \param out - The output stream to write output to
virtual void Print(std::ostream &out) const = 0;
/// \brief Pure virtual function to make a new copy of the node
/// \return The new copy of the node
virtual std::shared_ptr<DatasetNode> Copy() = 0;
/// \brief Print the IR tree to output stream
/// \param out - The output stream to write output to
void PrintTree(std::ostream &out) const;
/// \brief << Stream output operator overload
/// \notes This allows you to write the debug print info using stream operators
/// \param out - reference to the output stream being overloaded
/// \param node - reference to the DatasetNode to display; it is printed by calling node.PrintTree(out)
/// \return - the output stream must be returned
friend std::ostream &operator<<(std::ostream &out, const DatasetNode &node) {
node.PrintTree(out);
return out;
}
/// \brief Make a new copy of the tree from the current node
/// \return The new copy of the tree
std::shared_ptr<DatasetNode> DeepCopy();
/// \brief Pure virtual function to convert a DatasetNode class into a runtime dataset object
/// \return The list of shared pointers to the newly created DatasetOps
virtual std::vector<std::shared_ptr<DatasetOp>> Build() = 0;
@ -95,17 +165,38 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
/// \return Status Status::OK() if all the parameters are valid
virtual Status ValidateParams() = 0;
const std::vector<std::shared_ptr<DatasetNode>> Children() const { return children; }
/// \brief Pure virtual function for derived class to get the shard id of specific node
/// \return Status Status::OK() if get shard id successfully
virtual Status GetShardId(int32_t *shard_id);
/// \brief Getter function for child nodes
/// \return Child nodes
const std::vector<std::shared_ptr<DatasetNode>> Children() const { return children_; }
/// \brief Establish the parent-child relationship between this node and its child.
void AddChild(std::shared_ptr<DatasetNode> child);
/// \brief detach this node from its parent, add its child (if any) to its parent
/// \return error code, return error if node has more than 1 children
Status Remove();
/// \brief Check if this node has cache
/// \return True if the data of this node will be cached
const bool IsCached() const { return (cache_ != nullptr); }
/// \brief Setter function for runtime number of workers
/// \param[in] num_workers The number of threads in this operator
/// \return Shared pointer to the original object
std::shared_ptr<DatasetNode> SetNumWorkers(int32_t num_workers);
/// \brief A helper templated function for casting "this" pointer to shared_ptr<derived>
/// Similar to shared_from_this, except this one will give you the derived class as shared_ptr
/// \return A shared_ptr casted to the derived class
template <typename Derived>
std::shared_ptr<Derived> shared_from_base() {
return std::static_pointer_cast<Derived>(shared_from_this());
}
/// \brief Base method for NodePass visit. A tree walk consists of walking down the tree and also walking back up
/// in a depth-first order. Accept is the node visit on the way down, whereas AcceptAfter is the node
/// visit on the way back up the tree after its descendants are visited.
@ -129,17 +220,123 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
Status BuildStatus() { return build_status; }
protected:
std::vector<std::shared_ptr<DatasetNode>> children;
std::vector<std::shared_ptr<DatasetNode>> children_;
DatasetNode *parent_;
std::shared_ptr<DatasetCache> cache_;
Status AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops);
int32_t num_workers_;
int32_t rows_per_buffer_;
int32_t connector_que_size_;
int32_t worker_connector_size_;
Status build_status; // remove me after changing return val of Build()
std::string PrintColumns(const std::vector<std::string> &columns) const;
Status AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops);
void PrintNode(std::ostream &out, int *level) const;
};
// SourceNode represents the leaf nodes of a pipeline where the data is pulled into.
class SourceNode : public DatasetNode {
public:
/// \brief Constructor
SourceNode() : DatasetNode() {}
/// \brief Constructor that initializes the cache
/// \param dataset_cache DatasetCache
explicit SourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode(dataset_cache) {}
/// \brief Destructor
~SourceNode() = default;
/// \brief Node name getter
/// \return Name of the current node
virtual std::string Name() const = 0;
/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override;
/// \brief Check if this node is a mappable dataset. Only applicable to leaf nodes
/// \return True if the dataset represented by this node is a mappable dataset
const bool IsMappable() const { return mappable_; }
protected:
bool mappable_;
};
// MappableSourceNode represents the leaf nodes that can be randomly accessed with indexes.
class MappableSourceNode : public SourceNode {
public:
/// \brief Constructor
MappableSourceNode() : SourceNode() { mappable_ = true; }
/// \brief Constructor that initializes the cache
/// \param dataset_cache DatasetCache
explicit MappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : SourceNode(dataset_cache) {
mappable_ = true;
}
/// \brief Destructor
~MappableSourceNode() = default;
/// \brief Node name getter
/// \return Name of the current node
virtual std::string Name() const = 0;
};
// NonMappableSourceNode represents the leaf nodes that can not be randomly accessed.
class NonMappableSourceNode : public SourceNode {
public:
/// \brief Constructor
NonMappableSourceNode() : SourceNode() { mappable_ = false; }
/// \brief Constructor that initializes the cache
/// \param dataset_cache DatasetCache
explicit NonMappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : SourceNode(dataset_cache) {
mappable_ = false;
}
/// \brief Destructor
~NonMappableSourceNode() = default;
/// \brief Node name getter
/// \return Name of the current node
virtual std::string Name() const = 0;
};
// NonLeafNode represents operations over data in a pipeline.
class NonLeafNode : public DatasetNode {
public:
/// \brief Constructor
NonLeafNode() = default;
/// \brief Destructor
~NonLeafNode() = default;
/// \brief Node name getter
/// \return Name of the current node
virtual std::string Name() const = 0;
};
// SinkNode represents the end node of a pipeline where the data is pushed out
class SinkNode : public DatasetNode {
public:
/// \brief Constructor
SinkNode() = default;
/// \brief Destructor
~SinkNode() = default;
/// \brief Node name getter
/// \return Name of the current node
virtual std::string Name() const = 0;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_

View File

@ -0,0 +1,67 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h"
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
// Constructor for EpochCtrlNode
// \param child Node handed to AddChild() as this node's child (Copy() passes nullptr here)
// \param num_epochs Epoch count stored in num_epochs_; it is range-checked later by ValidateParams()
EpochCtrlNode::EpochCtrlNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs) : num_epochs_(num_epochs) {
// Hand the given child to AddChild(); this constructor does not itself assign any parent pointer.
this->AddChild(child);
}
std::shared_ptr<DatasetNode> EpochCtrlNode::Copy() {
auto node = std::make_shared<EpochCtrlNode>(nullptr, this->num_epochs_);
return node;
}
// Write "<Name>(epoch:<num_epochs>)" to the stream as a single string.
void EpochCtrlNode::Print(std::ostream &out) const {
  std::string text = Name();
  text += "(epoch:";
  text += std::to_string(num_epochs_);
  text += ")";
  out << text;
}
// Function to build the EpochCtrlOp
std::vector<std::shared_ptr<DatasetOp>> EpochCtrlNode::Build() {
// A dummy vector
std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<EpochCtrlOp>(num_epochs_));
return node_ops;
}
// Function to validate the parameters for EpochCtrlNode
Status EpochCtrlNode::ValidateParams() {
  // Only -1 or a strictly positive epoch count is accepted.
  const bool epochs_valid = (num_epochs_ == -1) || (num_epochs_ > 0);
  if (!epochs_valid) {
    std::string msg =
      "EpochCtrlNode: num_epochs should be either -1 or positive integer, num_epochs: " + std::to_string(num_epochs_);
    MS_LOG(ERROR) << msg;
    RETURN_STATUS_SYNTAX_ERROR(msg);
  }

  // Exactly one non-null child is required; the size test short-circuits before indexing.
  const bool has_single_child = (children_.size() == 1) && (children_[0] != nullptr);
  if (!has_single_child) {
    std::string msg = "Internal error: epoch control node should have one child node";
    MS_LOG(ERROR) << msg;
    RETURN_STATUS_SYNTAX_ERROR(msg);
  }

  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,63 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_EPOCH_CTRL_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_EPOCH_CTRL_NODE_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
/// \brief IR node that stores an epoch count and builds a single EpochCtrlOp from it.
class EpochCtrlNode : public DatasetNode {
public:
/// \brief Constructor
/// \param[in] child Node handed to AddChild() as this node's child
/// \param[in] num_epochs Epoch count; ValidateParams() accepts -1 or a positive value
explicit EpochCtrlNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs);
/// \brief Destructor
~EpochCtrlNode() = default;
/// \brief Node name getter
/// \return Name of the current node (kEpochCtrlNode)
std::string Name() const override { return kEpochCtrlNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
private:
// Epoch count given to the constructor; passed to EpochCtrlOp in Build().
int32_t num_epochs_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_EPOCH_CTRL_NODE_H_

View File

@ -21,7 +21,7 @@
#include <vector>
#include "minddata/dataset/engine/datasetops/filter_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
@ -31,7 +31,16 @@ namespace dataset {
FilterNode::FilterNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<TensorOp> predicate,
std::vector<std::string> input_columns)
: predicate_(predicate), input_columns_(input_columns) {
this->children.push_back(child);
this->AddChild(child);
}
// Create a new FilterNode that reuses this node's predicate and input columns; nullptr is passed in the child slot.
std::shared_ptr<DatasetNode> FilterNode::Copy() {
  return std::make_shared<FilterNode>(nullptr, predicate_, input_columns_);
}
// Write "<Name>(<predicate>,input_cols:<columns>)" to the stream as a single string.
void FilterNode::Print(std::ostream &out) const {
  std::string text = Name();
  text += "(<predicate>,";
  text += "input_cols:";
  text += PrintColumns(input_columns_);
  text += ")";
  out << text;
}
std::vector<std::shared_ptr<DatasetOp>> FilterNode::Build() {
@ -54,5 +63,17 @@ Status FilterNode::ValidateParams() {
return Status::OK();
}
// Visitor accepting method for NodePass
Status FilterNode::Accept(NodePass *p, bool *modified) {
  // Obtain this node as shared_ptr<FilterNode>, then pass it to the visitor
  std::shared_ptr<FilterNode> self = shared_from_base<FilterNode>();
  return p->Visit(self, modified);
}
// Visitor accepting method for NodePass
Status FilterNode::AcceptAfter(NodePass *p, bool *modified) {
  // Obtain this node as shared_ptr<FilterNode>, then pass it to the visitor
  std::shared_ptr<FilterNode> self = shared_from_base<FilterNode>();
  return p->VisitAfter(self, modified);
}
} // namespace dataset
} // namespace mindspore

View File

@ -35,6 +35,18 @@ class FilterNode : public DatasetNode {
/// \brief Destructor
~FilterNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kFilterNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
@ -43,6 +55,18 @@ class FilterNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override;
private:
std::shared_ptr<TensorOp> predicate_;
std::vector<std::string> input_columns_;

View File

@ -22,6 +22,7 @@
#include <vector>
#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/include/transforms.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
@ -37,7 +38,18 @@ MapNode::MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr
project_columns_(project_columns),
DatasetNode(std::move(cache)),
callbacks_(callbacks) {
this->children.push_back(child);
this->AddChild(child);
}
// Create a new MapNode from this node's operations, column lists, cache and callbacks;
// nullptr is passed in the child slot.
std::shared_ptr<DatasetNode> MapNode::Copy() {
  return std::make_shared<MapNode>(nullptr, operations_, input_columns_, output_columns_, project_columns_, cache_,
                                   callbacks_);
}
// Write "<Name>(<ops>,input:<cols>,output:<cols>,<project_cols>,...)" to the stream as a single string.
void MapNode::Print(std::ostream &out) const {
  std::string text = Name();
  text += "(<ops>";
  text += ",input:";
  text += PrintColumns(input_columns_);
  text += ",output:";
  text += PrintColumns(output_columns_);
  text += ",<project_cols>";
  text += ",...)";
  out << text;
}
std::vector<std::shared_ptr<DatasetOp>> MapNode::Build() {
@ -93,5 +105,16 @@ Status MapNode::ValidateParams() {
return Status::OK();
}
// Visitor accepting method for NodePass
Status MapNode::Accept(NodePass *p, bool *modified) {
  // Obtain this node as shared_ptr<MapNode>, then pass it to the visitor
  std::shared_ptr<MapNode> self = shared_from_base<MapNode>();
  return p->Visit(self, modified);
}
// Visitor accepting method for NodePass
Status MapNode::AcceptAfter(NodePass *p, bool *modified) {
  // Obtain this node as shared_ptr<MapNode>, then pass it to the visitor
  std::shared_ptr<MapNode> self = shared_from_base<MapNode>();
  return p->VisitAfter(self, modified);
}
} // namespace dataset
} // namespace mindspore

View File

@ -37,6 +37,18 @@ class MapNode : public DatasetNode {
/// \brief Destructor
~MapNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kMapNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
@ -45,6 +57,23 @@ class MapNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Getter of tensor operations
/// \return Vector of operations the Map node will process
const auto &TensorOperations() const { return operations_; }
auto &TensorOperations() { return operations_; }
/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override;
private:
std::vector<std::shared_ptr<TensorOperation>> operations_;
std::vector<std::string> input_columns_;

View File

@ -29,9 +29,16 @@ namespace dataset {
// Function to build ProjectOp
ProjectNode::ProjectNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &columns)
: columns_(columns) {
this->children.push_back(child);
this->AddChild(child);
}
std::shared_ptr<DatasetNode> ProjectNode::Copy() {
auto node = std::make_shared<ProjectNode>(nullptr, this->columns_);
return node;
}
// Write "<Name>(column: <columns>)" to the stream as a single string.
void ProjectNode::Print(std::ostream &out) const {
  std::string text = Name();
  text += "(column: ";
  text += PrintColumns(columns_);
  text += ")";
  out << text;
}
Status ProjectNode::ValidateParams() {
if (columns_.empty()) {
std::string err_msg = "ProjectNode: No columns are specified.";

View File

@ -34,6 +34,18 @@ class ProjectNode : public DatasetNode {
/// \brief Destructor
~ProjectNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kProjectNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

View File

@ -30,7 +30,16 @@ namespace dataset {
RenameNode::RenameNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &input_columns,
const std::vector<std::string> &output_columns)
: input_columns_(input_columns), output_columns_(output_columns) {
this->children.push_back(child);
this->AddChild(child);
}
// Create a new RenameNode with the same input/output column lists; nullptr is passed in the child slot.
std::shared_ptr<DatasetNode> RenameNode::Copy() {
  return std::make_shared<RenameNode>(nullptr, input_columns_, output_columns_);
}
// Write "<Name>(input:<cols>,output:<cols>)" to the stream as a single string.
void RenameNode::Print(std::ostream &out) const {
  std::string text = Name();
  text += "(input:";
  text += PrintColumns(input_columns_);
  text += ",output:";
  text += PrintColumns(output_columns_);
  text += ")";
  out << text;
}
Status RenameNode::ValidateParams() {

View File

@ -35,6 +35,18 @@ class RenameNode : public DatasetNode {
/// \brief Destructor
~RenameNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kRenameNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

View File

@ -21,15 +21,22 @@
#include <vector>
#include "minddata/dataset/engine/datasetops/repeat_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
RepeatNode::RepeatNode(std::shared_ptr<DatasetNode> child, int32_t count) : repeat_count_(count) {
this->children.push_back(child);
this->AddChild(child);
}
std::shared_ptr<DatasetNode> RepeatNode::Copy() {
auto node = std::make_shared<RepeatNode>(nullptr, this->repeat_count_);
return node;
}
// Write "<Name>(count:<repeat_count>)" to the stream as a single string.
void RepeatNode::Print(std::ostream &out) const {
  std::string text = Name();
  text += "(count:";
  text += std::to_string(repeat_count_);
  text += ")";
  out << text;
}
std::vector<std::shared_ptr<DatasetOp>> RepeatNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
@ -49,5 +56,16 @@ Status RepeatNode::ValidateParams() {
return Status::OK();
}
// Visitor accepting method for NodePass
Status RepeatNode::Accept(NodePass *p, bool *modified) {
  // Obtain this node as shared_ptr<RepeatNode>, then pass it to the visitor
  std::shared_ptr<RepeatNode> self = shared_from_base<RepeatNode>();
  return p->Visit(self, modified);
}
// Visitor accepting method for NodePass
Status RepeatNode::AcceptAfter(NodePass *p, bool *modified) {
  // Obtain this node as shared_ptr<RepeatNode>, then pass it to the visitor
  std::shared_ptr<RepeatNode> self = shared_from_base<RepeatNode>();
  return p->VisitAfter(self, modified);
}
} // namespace dataset
} // namespace mindspore

View File

@ -36,6 +36,18 @@ class RepeatNode : public DatasetNode {
/// \brief Destructor
~RepeatNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kRepeatNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
@ -44,6 +56,18 @@ class RepeatNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override;
private:
int32_t repeat_count_;
};

View File

@ -0,0 +1,85 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/root_node.h"
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
// Constructor for RootNode
// \param child Node handed to AddChild(); ValidateParams() later requires exactly one non-null child
// \param num_epochs Epoch count stored in num_epochs_; ValidateParams() accepts -1 or a positive value
RootNode::RootNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs) : DatasetNode(), num_epochs_(num_epochs) {
// The root node's parent must remain nullptr. (which is set in the constructor of DatasetNode)
AddChild(child);
}
// Create a new RootNode with the same epoch count; nullptr is passed in the child slot.
std::shared_ptr<DatasetNode> RootNode::Copy() {
  return std::make_shared<RootNode>(nullptr, num_epochs_);
}
// The root node prints only its name.
void RootNode::Print(std::ostream &out) const {
  const std::string label = Name();
  out << label;
}
std::vector<std::shared_ptr<DatasetOp>> RootNode::Build() {
  // root node doesn't build a runtime Op. this function should return Status::Error when called.
  // Until then the caller receives an empty list of ops.
  return std::vector<std::shared_ptr<DatasetOp>>();
}
// Function to validate the parameters for RootNode
Status RootNode::ValidateParams() {
  // Only -1 or a strictly positive epoch count is accepted.
  const bool epochs_valid = (num_epochs_ == -1) || (num_epochs_ > 0);
  if (!epochs_valid) {
    std::string msg =
      "RootNode: num_epochs should be either -1 or positive integer, num_epochs: " + std::to_string(num_epochs_);
    MS_LOG(ERROR) << msg;
    RETURN_STATUS_SYNTAX_ERROR(msg);
  }

  // The root must sit at the top of the tree.
  const bool has_parent = (parent_ != nullptr);
  if (has_parent) {
    std::string msg = "Internal error: root node should not have a parent";
    MS_LOG(ERROR) << msg;
    RETURN_STATUS_SYNTAX_ERROR(msg);
  }

  // Exactly one child is expected; the count is checked before the child is dereferenced below.
  const bool single_child = (children_.size() == 1);
  if (!single_child) {
    std::string msg = "Internal error: root node should have one child node";
    MS_LOG(ERROR) << msg;
    RETURN_STATUS_SYNTAX_ERROR(msg);
  }

  const bool child_is_null = (children_[0] == nullptr);
  if (child_is_null) {
    std::string msg = "Internal error: root node's child is a null pointer";
    MS_LOG(ERROR) << msg;
    RETURN_STATUS_SYNTAX_ERROR(msg);
  }

  return Status::OK();
}
// Visitor accepting method for NodePass
Status RootNode::Accept(NodePass *p, bool *modified) {
  // Obtain this node as shared_ptr<RootNode>, then pass it to the visitor
  std::shared_ptr<RootNode> self = shared_from_base<RootNode>();
  return p->Visit(self, modified);
}
// Visitor accepting method for NodePass
Status RootNode::AcceptAfter(NodePass *p, bool *modified) {
  // Obtain this node as shared_ptr<RootNode>, then pass it to the visitor
  std::shared_ptr<RootNode> self = shared_from_base<RootNode>();
  return p->VisitAfter(self, modified);
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,78 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_ROOT_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_ROOT_NODE_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
/// \brief IR node placed at the top of the tree. It stores an epoch count and builds no runtime op.
class RootNode : public DatasetNode {
 public:
  /// \brief Constructor
  /// \param[in] child Node handed to AddChild() as the root's child
  /// \param[in] num_epochs Epoch count; ValidateParams() accepts -1 or a positive value
  RootNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs);

  /// \brief Destructor
  ~RootNode() = default;

  /// \brief Node name getter
  /// \return Name of the current node
  std::string Name() const override { return kRootNode; }

  /// \brief Print the description
  /// \param out - The output stream to write output to
  void Print(std::ostream &out) const override;

  /// \brief Copy the node to a new object
  /// \return A shared pointer to the new copy
  std::shared_ptr<DatasetNode> Copy() override;

  /// \brief a base class override function to create the required runtime dataset op objects for this class
  /// \return shared pointer to the list of newly created DatasetOps
  std::vector<std::shared_ptr<DatasetOp>> Build() override;

  /// \brief Getter of number of epochs
  /// \return The epoch count given to the constructor
  // const-qualified so it can also be called on a const RootNode; it only reads num_epochs_.
  int32_t num_epochs() const { return num_epochs_; }

  /// \brief Parameters validation
  /// \return Status Status::OK() if all the parameters are valid
  Status ValidateParams() override;

  /// \brief Base-class override for accepting NodePass visitor
  /// \param[in] p The node to visit
  /// \param[out] modified Indicator if the node was modified
  /// \return Status of the node visit
  Status Accept(NodePass *p, bool *modified) override;

  /// \brief Base-class override for accepting NodePass visitor
  /// \param[in] p The node to visit
  /// \param[out] modified Indicator if the node was modified
  /// \return Status of the node visit
  Status AcceptAfter(NodePass *p, bool *modified) override;

 private:
  int32_t num_epochs_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_ROOT_NODE_H_

View File

@ -29,7 +29,17 @@ namespace dataset {
// Constructor for ShuffleNode
ShuffleNode::ShuffleNode(std::shared_ptr<DatasetNode> child, int32_t shuffle_size, bool reset_every_epoch)
: shuffle_size_(shuffle_size), shuffle_seed_(GetSeed()), reset_every_epoch_(reset_every_epoch) {
this->children.push_back(child);
this->AddChild(child);
}
// Create a new ShuffleNode with the same shuffle size and reset flag; nullptr is passed in the child slot.
// NOTE(review): the constructor initializes shuffle_seed_ from GetSeed(), so the copy's seed comes from a
// fresh GetSeed() call rather than from this node's shuffle_seed_ - confirm that is intended.
std::shared_ptr<DatasetNode> ShuffleNode::Copy() {
  return std::make_shared<ShuffleNode>(nullptr, shuffle_size_, reset_every_epoch_);
}
// Write "<Name>(shuffle_size:<n>,reset_every_epoch:<true|false>)" to the stream as a single string.
void ShuffleNode::Print(std::ostream &out) const {
  std::string text = Name();
  text += "(shuffle_size:";
  text += std::to_string(shuffle_size_);
  text += ",reset_every_epoch:";
  if (reset_every_epoch_) {
    text += "true";
  } else {
    text += "false";
  }
  text += ")";
  out << text;
}
// Function to build the ShuffleOp

View File

@ -34,6 +34,18 @@ class ShuffleNode : public DatasetNode {
~ShuffleNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kShuffleNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
std::vector<std::shared_ptr<DatasetOp>> Build() override;
Status ValidateParams() override;

View File

@ -27,10 +27,15 @@ namespace mindspore {
namespace dataset {
// Constructor for SkipNode
SkipNode::SkipNode(std::shared_ptr<DatasetNode> child, int32_t count) : skip_count_(count) {
this->children.push_back(child);
SkipNode::SkipNode(std::shared_ptr<DatasetNode> child, int32_t count) : skip_count_(count) { this->AddChild(child); }
// Create a new SkipNode with the same skip count; nullptr is passed in the child slot.
std::shared_ptr<DatasetNode> SkipNode::Copy() {
  return std::make_shared<SkipNode>(nullptr, skip_count_);
}
// Write "<Name>(skip_count:<n>)" to the stream as a single string.
void SkipNode::Print(std::ostream &out) const {
  std::string text = Name();
  text += "(skip_count:";
  text += std::to_string(skip_count_);
  text += ")";
  out << text;
}
// Function to build the SkipOp
std::vector<std::shared_ptr<DatasetOp>> SkipNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create

View File

@ -34,6 +34,18 @@ class SkipNode : public DatasetNode {
/// \brief Destructor
~SkipNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kSkipNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

View File

@ -32,13 +32,23 @@ namespace dataset {
AlbumNode::AlbumNode(const std::string &dataset_dir, const std::string &data_schema,
const std::vector<std::string> &column_names, bool decode,
const std::shared_ptr<SamplerObj> &sampler, const std::shared_ptr<DatasetCache> &cache)
: DatasetNode(std::move(cache)),
: MappableSourceNode(std::move(cache)),
dataset_dir_(dataset_dir),
schema_path_(data_schema),
column_names_(column_names),
decode_(decode),
sampler_(sampler) {}
std::shared_ptr<DatasetNode> AlbumNode::Copy() {
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
auto node = std::make_shared<AlbumNode>(dataset_dir_, schema_path_, column_names_, decode_, sampler, cache_);
return node;
}
void AlbumNode::Print(std::ostream &out) const {
out << Name() + "(cache:" + ((cache_ != nullptr) ? "true" : "false") + ")";
}
Status AlbumNode::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetDirParam("AlbumNode", dataset_dir_));

View File

@ -26,7 +26,7 @@
namespace mindspore {
namespace dataset {
class AlbumNode : public DatasetNode {
class AlbumNode : public MappableSourceNode {
public:
/// \brief Constructor
AlbumNode(const std::string &dataset_dir, const std::string &data_schema,
@ -36,6 +36,18 @@ class AlbumNode : public DatasetNode {
/// \brief Destructor
~AlbumNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kAlbumNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

View File

@ -31,13 +31,23 @@ namespace dataset {
CelebANode::CelebANode(const std::string &dataset_dir, const std::string &usage,
const std::shared_ptr<SamplerObj> &sampler, const bool &decode,
const std::set<std::string> &extensions, const std::shared_ptr<DatasetCache> &cache)
: DatasetNode(std::move(cache)),
: MappableSourceNode(std::move(cache)),
dataset_dir_(dataset_dir),
usage_(usage),
sampler_(sampler),
decode_(decode),
extensions_(extensions) {}
std::shared_ptr<DatasetNode> CelebANode::Copy() {
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
auto node = std::make_shared<CelebANode>(dataset_dir_, usage_, sampler, decode_, extensions_, cache_);
return node;
}
void CelebANode::Print(std::ostream &out) const {
out << Name() + "(cache:" + ((cache_ != nullptr) ? "true" : "false") + ")";
}
Status CelebANode::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetDirParam("CelebANode", dataset_dir_));

View File

@ -28,7 +28,7 @@
namespace mindspore {
namespace dataset {
class CelebANode : public DatasetNode {
class CelebANode : public MappableSourceNode {
public:
/// \brief Constructor
CelebANode(const std::string &dataset_dir, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler,
@ -37,6 +37,18 @@ class CelebANode : public DatasetNode {
/// \brief Destructor
~CelebANode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kCelebANode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

View File

@ -30,7 +30,17 @@ namespace dataset {
// Constructor for Cifar100Node
Cifar100Node::Cifar100Node(const std::string &dataset_dir, const std::string &usage,
std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache)
: DatasetNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
: MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
std::shared_ptr<DatasetNode> Cifar100Node::Copy() {
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
auto node = std::make_shared<Cifar100Node>(dataset_dir_, usage_, sampler, cache_);
return node;
}
void Cifar100Node::Print(std::ostream &out) const {
out << Name() + "(cache:" + ((cache_ != nullptr) ? "true" : "false") + ")";
}
Status Cifar100Node::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar100Node", dataset_dir_));

View File

@ -26,7 +26,7 @@
namespace mindspore {
namespace dataset {
class Cifar100Node : public DatasetNode {
class Cifar100Node : public MappableSourceNode {
public:
/// \brief Constructor
Cifar100Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler,
@ -35,6 +35,18 @@ class Cifar100Node : public DatasetNode {
/// \brief Destructor
~Cifar100Node() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kCifar100Node; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

View File

@ -30,7 +30,17 @@ namespace dataset {
// Constructor for Cifar10Node
Cifar10Node::Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler,
std::shared_ptr<DatasetCache> cache)
: DatasetNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
: MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
std::shared_ptr<DatasetNode> Cifar10Node::Copy() {
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
auto node = std::make_shared<Cifar10Node>(dataset_dir_, usage_, sampler, cache_);
return node;
}
void Cifar10Node::Print(std::ostream &out) const {
out << Name() + "(cache:" + ((cache_ != nullptr) ? "true" : "false") + ")";
}
Status Cifar10Node::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar10Node", dataset_dir_));

View File

@ -26,7 +26,7 @@
namespace mindspore {
namespace dataset {
class Cifar10Node : public DatasetNode {
class Cifar10Node : public MappableSourceNode {
public:
/// \brief Constructor
Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler,
@ -35,6 +35,18 @@ class Cifar10Node : public DatasetNode {
/// \brief Destructor
~Cifar10Node() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kCifar10Node; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

View File

@ -32,7 +32,7 @@ namespace dataset {
// Constructor for CLUENode
CLUENode::CLUENode(const std::vector<std::string> clue_files, std::string task, std::string usage, int64_t num_samples,
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache)
: DatasetNode(std::move(cache)),
: NonMappableSourceNode(std::move(cache)),
dataset_files_(clue_files),
task_(task),
usage_(usage),
@ -41,6 +41,17 @@ CLUENode::CLUENode(const std::vector<std::string> clue_files, std::string task,
num_shards_(num_shards),
shard_id_(shard_id) {}
// Create a new CLUENode built from this node's parameters; the cache pointer is passed through as-is.
std::shared_ptr<DatasetNode> CLUENode::Copy() {
  std::shared_ptr<DatasetNode> copy =
    std::make_shared<CLUENode>(dataset_files_, task_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_);
  return copy;
}
void CLUENode::Print(std::ostream &out) const {
out << Name() + "(cache:" + ((cache_ != nullptr) ? "true" : "false") + ",..." +
",num_shards:" + std::to_string(num_shards_) + ",shard_id:" + std::to_string(shard_id_) + ")";
}
Status CLUENode::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CLUENode", dataset_files_));

View File

@ -28,7 +28,7 @@ namespace dataset {
/// \class CLUENode
/// \brief A Dataset derived class to represent CLUE dataset
class CLUENode : public DatasetNode {
class CLUENode : public NonMappableSourceNode {
public:
/// \brief Constructor
CLUENode(const std::vector<std::string> dataset_files, std::string task, std::string usage, int64_t num_samples,
@ -37,6 +37,18 @@ class CLUENode : public DatasetNode {
/// \brief Destructor
~CLUENode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kCLUENode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

View File

@ -30,13 +30,21 @@ namespace dataset {
// Constructor for CocoNode
CocoNode::CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task,
const bool &decode, const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache)
: DatasetNode(std::move(cache)),
: MappableSourceNode(std::move(cache)),
dataset_dir_(dataset_dir),
annotation_file_(annotation_file),
task_(task),
decode_(decode),
sampler_(sampler) {}
std::shared_ptr<DatasetNode> CocoNode::Copy() {
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
auto node = std::make_shared<CocoNode>(dataset_dir_, annotation_file_, task_, decode_, sampler, cache_);
return node;
}
// Write this node's description, which is just its name.
void CocoNode::Print(std::ostream &out) const {
  const std::string description = Name();
  out << description;
}
Status CocoNode::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetDirParam("CocoNode", dataset_dir_));

View File

@ -26,7 +26,7 @@
namespace mindspore {
namespace dataset {
class CocoNode : public DatasetNode {
class CocoNode : public MappableSourceNode {
public:
/// \brief Constructor
CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task,
@ -35,6 +35,18 @@ class CocoNode : public DatasetNode {
/// \brief Destructor
~CocoNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kCocoNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

View File

@ -33,7 +33,7 @@ CSVNode::CSVNode(const std::vector<std::string> &csv_files, char field_delim,
const std::vector<std::shared_ptr<CsvBase>> &column_defaults,
const std::vector<std::string> &column_names, int64_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache)
: DatasetNode(std::move(cache)),
: NonMappableSourceNode(std::move(cache)),
dataset_files_(csv_files),
field_delim_(field_delim),
column_defaults_(column_defaults),
@ -43,6 +43,17 @@ CSVNode::CSVNode(const std::vector<std::string> &csv_files, char field_delim,
num_shards_(num_shards),
shard_id_(shard_id) {}
// Create a new CSVNode built from this node's parameters; the cache pointer is passed through as-is.
std::shared_ptr<DatasetNode> CSVNode::Copy() {
  std::shared_ptr<DatasetNode> copy =
    std::make_shared<CSVNode>(dataset_files_, field_delim_, column_defaults_, column_names_, num_samples_, shuffle_,
                              num_shards_, shard_id_, cache_);
  return copy;
}
void CSVNode::Print(std::ostream &out) const {
out << Name() + "(cache:" + ((cache_ != nullptr) ? "true" : "false") + ",..." +
",num_shards:" + std::to_string(num_shards_) + ",shard_id:" + std::to_string(shard_id_) + ")";
}
Status CSVNode::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CSVNode", dataset_files_));

View File

@ -47,7 +47,7 @@ class CsvRecord : public CsvBase {
T value;
};
class CSVNode : public DatasetNode {
class CSVNode : public NonMappableSourceNode {
public:
/// \brief Constructor
CSVNode(const std::vector<std::string> &dataset_files, char field_delim,
@ -58,6 +58,18 @@ class CSVNode : public DatasetNode {
/// \brief Destructor
~CSVNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kCSVNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

View File

@ -28,7 +28,19 @@ namespace dataset {
GeneratorNode::GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names,
const std::vector<DataType> &column_types)
: generator_function_(generator_function), column_names_(column_names), column_types_(column_types) {}
: MappableSourceNode(),
generator_function_(generator_function),
column_names_(column_names),
column_types_(column_types) {}
std::shared_ptr<DatasetNode> GeneratorNode::Copy() {
auto node = std::make_shared<GeneratorNode>(generator_function_, column_names_, column_types_);
return node;
}
// Write a one-line description of this node: its name and column names.
// The generator function and column types are shown only as placeholders.
void GeneratorNode::Print(std::ostream &out) const {
  std::string description = Name();
  description += "(<func>:";
  description += ",columns:";
  description += PrintColumns(column_names_);
  description += ",<col_types>)";
  out << description;
}
// Constructor for GeneratorNode taking a schema object instead of explicit column names and types.
GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema)
    : MappableSourceNode(), generator_function_(generator_function), schema_(schema) {}

View File

@ -26,10 +26,9 @@
namespace mindspore {
namespace dataset {
/// \class GeneratorNode
/// \brief A Dataset derived class to represent GeneratorNode dataset
class GeneratorNode : public DatasetNode {
class GeneratorNode : public MappableSourceNode {
public:
/// \brief Constructor
GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names,
@ -41,6 +40,18 @@ class GeneratorNode : public DatasetNode {
/// \brief Destructor
~GeneratorNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kGeneratorNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

View File

@ -33,13 +33,24 @@ ImageFolderNode::ImageFolderNode(std::string dataset_dir, bool decode, std::shar
bool recursive, std::set<std::string> extensions,
std::map<std::string, int32_t> class_indexing,
std::shared_ptr<DatasetCache> cache = nullptr)
: dataset_dir_(dataset_dir),
: MappableSourceNode(std::move(cache)),
dataset_dir_(dataset_dir),
decode_(decode),
sampler_(sampler),
recursive_(recursive),
class_indexing_(class_indexing),
exts_(extensions),
DatasetNode(std::move(cache)) {}
exts_(extensions) {}
std::shared_ptr<DatasetNode> ImageFolderNode::Copy() {
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
auto node =
std::make_shared<ImageFolderNode>(dataset_dir_, decode_, sampler, recursive_, exts_, class_indexing_, cache_);
return node;
}
void ImageFolderNode::Print(std::ostream &out) const {
out << Name() + "(path:" + dataset_dir_ + ",decode:" + (decode_ ? "true" : "false") + ",...)";
}
Status ImageFolderNode::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetDirParam("ImageFolderNode", dataset_dir_));

View File

@ -31,7 +31,7 @@ namespace dataset {
/// \class ImageFolderNode
/// \brief A Dataset derived class to represent ImageFolder dataset
class ImageFolderNode : public DatasetNode {
class ImageFolderNode : public MappableSourceNode {
public:
/// \brief Constructor
ImageFolderNode(std::string dataset_dir, bool decode, std::shared_ptr<SamplerObj> sampler, bool recursive,
@ -41,6 +41,18 @@ class ImageFolderNode : public DatasetNode {
/// \brief Destructor
~ImageFolderNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kImageFolderNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

View File

@ -32,13 +32,30 @@ ManifestNode::ManifestNode(const std::string &dataset_file, const std::string &u
const std::shared_ptr<SamplerObj> &sampler,
const std::map<std::string, int32_t> &class_indexing, bool decode,
std::shared_ptr<DatasetCache> cache)
: DatasetNode(std::move(cache)),
: MappableSourceNode(std::move(cache)),
dataset_file_(dataset_file),
usage_(usage),
decode_(decode),
class_index_(class_indexing),
sampler_(sampler) {}
std::shared_ptr<DatasetNode> ManifestNode::Copy() {
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
auto node = std::make_shared<ManifestNode>(dataset_file_, usage_, sampler, class_index_, decode_, cache_);
return node;
}
// Write a one-line description of this node: its name and dataset file, followed by
// ",sampler" and/or ",cache" markers when those members are set.
void ManifestNode::Print(std::ostream &out) const {
  const bool has_sampler = (sampler_ != nullptr);
  const bool has_cache = (cache_ != nullptr);

  std::string header = Name();
  header += "(file:";
  header += dataset_file_;
  out << header;

  if (has_sampler) {
    out << ",sampler";
  }
  if (has_cache) {
    out << ",cache";
  }
  out << ")";
}
Status ManifestNode::ValidateParams() {
std::vector<char> forbidden_symbols = {':', '*', '?', '"', '<', '>', '|', '`', '&', '\'', ';'};
for (char c : dataset_file_) {

View File

@ -27,7 +27,7 @@
namespace mindspore {
namespace dataset {
class ManifestNode : public DatasetNode {
class ManifestNode : public MappableSourceNode {
public:
/// \brief Constructor
ManifestNode(const std::string &dataset_file, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler,
@ -36,6 +36,18 @@ class ManifestNode : public DatasetNode {
/// \brief Destructor
~ManifestNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kManifestNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

View File

@ -30,7 +30,8 @@ namespace dataset {
MindDataNode::MindDataNode(const std::vector<std::string> &dataset_files, const std::vector<std::string> &columns_list,
const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample, int64_t num_padded)
: dataset_file_(std::string()),
: MappableSourceNode(),
dataset_file_(std::string()),
dataset_files_(dataset_files),
search_for_pattern_(false),
columns_list_(columns_list),
@ -41,7 +42,8 @@ MindDataNode::MindDataNode(const std::vector<std::string> &dataset_files, const
MindDataNode::MindDataNode(const std::string &dataset_file, const std::vector<std::string> &columns_list,
const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample, int64_t num_padded)
: dataset_file_(dataset_file),
: MappableSourceNode(),
dataset_file_(dataset_file),
dataset_files_({}),
search_for_pattern_(true),
columns_list_(columns_list),
@ -50,6 +52,19 @@ MindDataNode::MindDataNode(const std::string &dataset_file, const std::vector<st
sample_bytes_({}),
num_padded_(num_padded) {}
std::shared_ptr<DatasetNode> MindDataNode::Copy() {
std::shared_ptr<MindDataNode> node;
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
if (dataset_files_.empty()) {
node = std::make_shared<MindDataNode>(dataset_file_, columns_list_, sampler, padded_sample_, num_padded_);
} else {
node = std::make_shared<MindDataNode>(dataset_files_, columns_list_, sampler, padded_sample_, num_padded_);
}
return node;
}
// Write a one-line description of this node: its name and dataset_file_.
// Other parameters are elided as "...".
void MindDataNode::Print(std::ostream &out) const {
  std::string description = Name();
  description += "(file:";
  description += dataset_file_;
  description += ",...)";
  out << description;
}
Status MindDataNode::ValidateParams() {
if (!search_for_pattern_ && dataset_files_.size() > 4096) {
std::string err_msg =

View File

@ -27,7 +27,7 @@
namespace mindspore {
namespace dataset {
class MindDataNode : public DatasetNode {
class MindDataNode : public MappableSourceNode {
public:
/// \brief Constructor
MindDataNode(const std::vector<std::string> &dataset_files, const std::vector<std::string> &columns_list,
@ -40,6 +40,18 @@ class MindDataNode : public DatasetNode {
/// \brief Destructor
~MindDataNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kMindDataNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

View File

@ -29,7 +29,15 @@ namespace dataset {
MnistNode::MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler,
std::shared_ptr<DatasetCache> cache)
: DatasetNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
: MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
std::shared_ptr<DatasetNode> MnistNode::Copy() {
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
auto node = std::make_shared<MnistNode>(dataset_dir_, usage_, sampler, cache_);
return node;
}
// Write this node's description, which is just its name.
void MnistNode::Print(std::ostream &out) const {
  const std::string description = Name();
  out << description;
}
Status MnistNode::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetDirParam("MnistNode", dataset_dir_));

View File

@ -26,7 +26,7 @@
namespace mindspore {
namespace dataset {
class MnistNode : public DatasetNode {
class MnistNode : public MappableSourceNode {
public:
/// \brief Constructor
MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler,
@ -35,6 +35,18 @@ class MnistNode : public DatasetNode {
/// \brief Destructor
~MnistNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kMnistNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

View File

@ -27,6 +27,18 @@
namespace mindspore {
namespace dataset {
// Create a new RandomNode built from this node's parameters.
// The schema-object constructor is used when schema_ is set; otherwise the schema-path constructor is used.
// The cache pointer is passed through as-is.
std::shared_ptr<DatasetNode> RandomNode::Copy() {
  if (schema_ == nullptr) {
    return std::make_shared<RandomNode>(total_rows_, schema_path_, columns_list_, cache_);
  }
  return std::make_shared<RandomNode>(total_rows_, schema_, columns_list_, cache_);
}
// Write a one-line description of this node: its name and total row count.
// Other parameters are elided as "...".
void RandomNode::Print(std::ostream &out) const {
  std::string description = Name();
  description += "(num_row:";
  description += std::to_string(total_rows_);
  description += ",...)";
  out << description;
}
// ValidateParams for RandomNode
Status RandomNode::ValidateParams() {
if (total_rows_ < 0) {

View File

@ -27,7 +27,7 @@
namespace mindspore {
namespace dataset {
class RandomNode : public DatasetNode {
class RandomNode : public NonMappableSourceNode {
public:
// Some constants to provide limits to random generation.
static constexpr int32_t kMaxNumColumns = 4;
@ -37,7 +37,7 @@ class RandomNode : public DatasetNode {
/// \brief Constructor
RandomNode(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema, const std::vector<std::string> &columns_list,
std::shared_ptr<DatasetCache> cache)
: DatasetNode(std::move(cache)),
: NonMappableSourceNode(std::move(cache)),
total_rows_(total_rows),
schema_path_(""),
schema_(std::move(schema)),
@ -46,14 +46,27 @@ class RandomNode : public DatasetNode {
/// \brief Constructor
RandomNode(const int32_t &total_rows, std::string schema_path, const std::vector<std::string> &columns_list,
std::shared_ptr<DatasetCache> cache)
: DatasetNode(std::move(cache)),
: NonMappableSourceNode(std::move(cache)),
total_rows_(total_rows),
schema_path_(schema_path),
schema_(nullptr),
columns_list_(columns_list) {}
/// \brief Destructor
~RandomNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kRandomNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

View File

@ -31,13 +31,23 @@ namespace dataset {
// Constructor for TextFileNode
TextFileNode::TextFileNode(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache)
: DatasetNode(std::move(cache)),
: NonMappableSourceNode(std::move(cache)),
dataset_files_(dataset_files),
num_samples_(num_samples),
shuffle_(shuffle),
num_shards_(num_shards),
shard_id_(shard_id) {}
// Create a new TextFileNode built from this node's parameters; the cache pointer is passed through as-is.
std::shared_ptr<DatasetNode> TextFileNode::Copy() {
  std::shared_ptr<DatasetNode> copy =
    std::make_shared<TextFileNode>(dataset_files_, num_samples_, shuffle_, num_shards_, shard_id_, cache_);
  return copy;
}
void TextFileNode::Print(std::ostream &out) const {
out << Name() + "(file:..." + ",num_shards:" + std::to_string(num_shards_) +
",shard_id:" + std::to_string(shard_id_) + ",cache:" + ((cache_ != nullptr) ? "true" : "false") + ",...)";
}
Status TextFileNode::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("TextFileNode", dataset_files_));

View File

@ -28,7 +28,7 @@ namespace dataset {
/// \class TextFileNode
/// \brief A Dataset derived class to represent TextFile dataset
class TextFileNode : public DatasetNode {
class TextFileNode : public NonMappableSourceNode {
public:
/// \brief Constructor
TextFileNode(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards,
@ -37,6 +37,18 @@ class TextFileNode : public DatasetNode {
/// \brief Destructor
~TextFileNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kTextFileNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

View File

@ -30,6 +30,23 @@
namespace mindspore {
namespace dataset {
// Create a new TFRecordNode built from this node's parameters.
// The schema-object constructor is used when schema_obj_ is set; otherwise the schema-path constructor is used.
// The cache pointer is passed through as-is.
std::shared_ptr<DatasetNode> TFRecordNode::Copy() {
  if (schema_obj_ == nullptr) {
    return std::make_shared<TFRecordNode>(dataset_files_, schema_path_, columns_list_, num_samples_, shuffle_,
                                          num_shards_, shard_id_, shard_equal_rows_, cache_);
  }
  return std::make_shared<TFRecordNode>(dataset_files_, schema_obj_, columns_list_, num_samples_, shuffle_,
                                        num_shards_, shard_id_, shard_equal_rows_, cache_);
}
// Write a one-line description of this node: its name, sample count and shard configuration.
// Other parameters are elided as "...".
void TFRecordNode::Print(std::ostream &out) const {
  std::string description = Name();
  description += "(num_samples:";
  description += std::to_string(num_samples_);
  description += ",num_shards:";
  description += std::to_string(num_shards_);
  description += ",shard_id:";
  description += std::to_string(shard_id_);
  description += ",...)";
  out << description;
}
// Validator for TFRecordNode
Status TFRecordNode::ValidateParams() {
if (dataset_files_.empty()) {

View File

@ -29,14 +29,14 @@ namespace dataset {
/// \class TFRecordNode
/// \brief A Dataset derived class to represent TFRecord dataset
class TFRecordNode : public DatasetNode {
class TFRecordNode : public NonMappableSourceNode {
public:
/// \brief Constructor
/// \note Parameter 'schema' is the path to the schema file
TFRecordNode(const std::vector<std::string> &dataset_files, std::string schema,
const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache)
: DatasetNode(std::move(cache)),
: NonMappableSourceNode(std::move(cache)),
dataset_files_(dataset_files),
schema_path_(schema),
columns_list_(columns_list),
@ -51,7 +51,7 @@ class TFRecordNode : public DatasetNode {
TFRecordNode(const std::vector<std::string> &dataset_files, std::shared_ptr<SchemaObj> schema,
const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache)
: DatasetNode(std::move(cache)),
: NonMappableSourceNode(std::move(cache)),
dataset_files_(dataset_files),
schema_obj_(schema),
columns_list_(columns_list),
@ -64,6 +64,18 @@ class TFRecordNode : public DatasetNode {
/// \brief Destructor
~TFRecordNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kTFRecordNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

View File

@ -32,7 +32,7 @@ namespace dataset {
VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage,
const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<SamplerObj> sampler,
std::shared_ptr<DatasetCache> cache)
: DatasetNode(std::move(cache)),
: MappableSourceNode(std::move(cache)),
dataset_dir_(dataset_dir),
task_(task),
usage_(usage),
@ -40,6 +40,14 @@ VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const
decode_(decode),
sampler_(sampler) {}
std::shared_ptr<DatasetNode> VOCNode::Copy() {
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
auto node = std::make_shared<VOCNode>(dataset_dir_, task_, usage_, class_index_, decode_, sampler, cache_);
return node;
}
// Write the description of this node to the given stream; for VOCNode this is only the node name.
void VOCNode::Print(std::ostream &out) const {
  const std::string node_name = Name();
  out << node_name;
}
Status VOCNode::ValidateParams() {
Path dir(dataset_dir_);

View File

@ -27,7 +27,7 @@
namespace mindspore {
namespace dataset {
class VOCNode : public DatasetNode {
class VOCNode : public MappableSourceNode {
public:
/// \brief Constructor
VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage,
@ -37,6 +37,18 @@ class VOCNode : public DatasetNode {
/// \brief Destructor
~VOCNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kVOCNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

View File

@ -29,7 +29,16 @@ namespace dataset {
// Constructor for SyncWaitNode
SyncWaitNode::SyncWaitNode(std::shared_ptr<DatasetNode> child, const std::string &condition_name, py::function callback)
: condition_name_(condition_name), callback_(callback) {
this->children.push_back(child);
this->AddChild(child);
}
// Create a new SyncWaitNode with the same condition name and callback as this node.
// A null child is passed to the constructor, so this node's child is not handed to the copy.
std::shared_ptr<DatasetNode> SyncWaitNode::Copy() {
  return std::make_shared<SyncWaitNode>(nullptr, condition_name_, callback_);
}
// Write the description of this node to the given stream: the node name followed by the condition name
// and a fixed "<pyfunc>" marker standing in for the callback.
void SyncWaitNode::Print(std::ostream &out) const {
  std::string description = Name();
  description += "(cond_name:";
  description += condition_name_;
  description += "<pyfunc>";
  description += ")";
  // Emit the text with a single insertion, as one complete string
  out << description;
}
// Function to build the BarrierOp

View File

@ -36,6 +36,18 @@ class SyncWaitNode : public DatasetNode {
/// \brief Destructor
~SyncWaitNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kSyncWaitNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

View File

@ -27,10 +27,15 @@ namespace mindspore {
namespace dataset {
// Constructor for TakeNode
TakeNode::TakeNode(std::shared_ptr<DatasetNode> child, int32_t count) : take_count_(count) {
this->children.push_back(child);
TakeNode::TakeNode(std::shared_ptr<DatasetNode> child, int32_t count) : take_count_(count) { this->AddChild(child); }
// Create a new TakeNode with the same take count as this node.
// A null child is passed to the constructor, so this node's child is not handed to the copy.
std::shared_ptr<DatasetNode> TakeNode::Copy() {
  return std::make_shared<TakeNode>(nullptr, take_count_);
}
// Write the description of this node to the given stream: the node name followed by the take count.
void TakeNode::Print(std::ostream &out) const {
  std::string description = Name();
  description += "(num_rows:";
  description += std::to_string(take_count_);
  description += ")";
  // Emit the text with a single insertion, as one complete string
  out << description;
}
// Function to build the TakeOp
std::vector<std::shared_ptr<DatasetOp>> TakeNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create

View File

@ -34,6 +34,18 @@ class TakeNode : public DatasetNode {
/// \brief Destructor
~TakeNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kTakeNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

View File

@ -22,6 +22,7 @@
#include <vector>
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/status.h"
#include "utils/ms_context.h"
@ -39,7 +40,19 @@ TransferNode::TransferNode(std::shared_ptr<DatasetNode> child, std::string queue
total_batch_(total_batch),
create_data_info_queue_(create_data_info_queue),
device_id_(0) {
this->children.push_back(child);
this->AddChild(child);
}
// Create a new TransferNode from this node's queue name, device type, send_epoch_end flag, total batch
// and create_data_info_queue flag.
// A null child is passed to the constructor, so this node's child is not handed to the copy.
std::shared_ptr<DatasetNode> TransferNode::Copy() {
  return std::make_shared<TransferNode>(nullptr, queue_name_, device_type_, send_epoch_end_, total_batch_,
                                        create_data_info_queue_);
}
// Write the description of this node to the given stream: the node name followed by the prefetch size,
// the send_epoch_end flag (as "true"/"false") and the total batch.
void TransferNode::Print(std::ostream &out) const {
  const std::string epoch_end_text = send_epoch_end_ ? "true" : "false";
  std::string description = Name();
  description += "(prefetch_size:" + std::to_string(prefetch_size_);
  description += ",send_epoch_end:" + epoch_end_text;
  description += ",total_batch:" + std::to_string(total_batch_);
  description += ")";
  // Emit the text with a single insertion, as one complete string
  out << description;
}
// Validator for TransferNode
@ -94,5 +107,16 @@ std::vector<std::shared_ptr<DatasetOp>> TransferNode::Build() {
return node_ops;
}
// Visitor accepting method for NodePass: hands this node, typed as TransferNode, to the pass's Visit()
Status TransferNode::Accept(NodePass *p, bool *modified) {
  // Downcast the shared pointer to the concrete node type first
  auto self = shared_from_base<TransferNode>();
  // Then call the visitor with it
  return p->Visit(self, modified);
}
// Visitor accepting method for NodePass: hands this node, typed as TransferNode, to the pass's VisitAfter()
Status TransferNode::AcceptAfter(NodePass *p, bool *modified) {
  // Downcast the shared pointer to the concrete node type first
  auto self = shared_from_base<TransferNode>();
  // Then call the visitor with it
  return p->VisitAfter(self, modified);
}
} // namespace dataset
} // namespace mindspore

View File

@ -35,6 +35,18 @@ class TransferNode : public DatasetNode {
/// \brief Destructor
~TransferNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kTransferNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
@ -43,6 +55,20 @@ class TransferNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
static Status get_distribution(std::shared_ptr<DatasetNode> ds, int32_t *device_id);
/// \brief Base-class override for accepting a NodePass visitor; dispatches to the pass's Visit() for this node
/// \param[in] p The NodePass visitor to apply to this node
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
/// \brief Base-class override for accepting a NodePass visitor; dispatches to the pass's VisitAfter() for this node
/// \param[in] p The NodePass visitor to apply to this node
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override;
private:
std::string queue_name_;
int32_t device_id_;

View File

@ -21,30 +21,36 @@
#include <vector>
#include "minddata/dataset/engine/datasetops/zip_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
ZipNode::ZipNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets) : datasets_(datasets) {
for (auto dataset : datasets_) {
this->children.push_back(dataset);
}
ZipNode::ZipNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets) {
for (auto const &child : datasets) AddChild(child);
}
std::shared_ptr<DatasetNode> ZipNode::Copy() {
std::vector<std::shared_ptr<DatasetNode>> empty_vector;
empty_vector.clear();
auto node = std::make_shared<ZipNode>(empty_vector);
return node;
}
// Write the description of this node to the given stream; for ZipNode this is only the node name.
void ZipNode::Print(std::ostream &out) const {
  const std::string node_name = Name();
  out << node_name;
}
Status ZipNode::ValidateParams() {
if (datasets_.empty()) {
std::string err_msg = "ZipNode: datasets to zip are not specified.";
if (children_.size() < 2) {
std::string err_msg = "ZipNode: input datasets are not specified.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (find(datasets_.begin(), datasets_.end(), nullptr) != datasets_.end()) {
std::string err_msg = "ZipNode: zip datasets should not be null.";
if (find(children_.begin(), children_.end(), nullptr) != children_.end()) {
std::string err_msg = "ZipNode: input datasets should not be null.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}
@ -56,5 +62,17 @@ std::vector<std::shared_ptr<DatasetOp>> ZipNode::Build() {
return node_ops;
}
// Visitor accepting method for NodePass: hands this node, typed as ZipNode, to the pass's Visit()
Status ZipNode::Accept(NodePass *p, bool *modified) {
  // Downcast the shared pointer to the concrete node type first
  auto self = shared_from_base<ZipNode>();
  // Then call the visitor with it
  return p->Visit(self, modified);
}
// Visitor accepting method for NodePass: hands this node, typed as ZipNode, to the pass's VisitAfter()
Status ZipNode::AcceptAfter(NodePass *p, bool *modified) {
  // Downcast the shared pointer to the concrete node type first
  auto self = shared_from_base<ZipNode>();
  // Then call the visitor with it
  return p->VisitAfter(self, modified);
}
} // namespace dataset
} // namespace mindspore

View File

@ -34,6 +34,18 @@ class ZipNode : public DatasetNode {
/// \brief Destructor
~ZipNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kZipNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
@ -42,8 +54,17 @@ class ZipNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
private:
std::vector<std::shared_ptr<DatasetNode>> datasets_;
/// \brief Base-class override for accepting a NodePass visitor; dispatches to the pass's Visit() for this node
/// \param[in] p The NodePass visitor to apply to this node
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
/// \brief Base-class override for accepting a NodePass visitor; dispatches to the pass's VisitAfter() for this node
/// \param[in] p The NodePass visitor to apply to this node
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override;
};
} // namespace dataset

View File

@ -22,10 +22,12 @@
#endif
#include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
#include "minddata/dataset/engine/ir/datasetops/filter_node.h"
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
#include "minddata/dataset/engine/ir/datasetops/root_node.h"
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
#ifdef ENABLE_PYTHON
@ -34,34 +36,6 @@
#include "minddata/dataset/engine/ir/datasetops/take_node.h"
#include "minddata/dataset/engine/ir/datasetops/transfer_node.h"
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/album_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
#endif
#include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
#endif
#ifdef ENABLE_PYTHON
#include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
#endif
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
#endif
#include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
#endif
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
#endif
#include "minddata/dataset/engine/ir/datasetops/source/voc_node.h"
//////////////////////////////////
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
@ -113,7 +87,12 @@ namespace mindspore {
namespace dataset {
// Driver method for TreePass
Status TreePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) { return Status::OK(); }
Status TreePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) {
  // Both the tree root and the output flag must be supplied before the pass can run
  const bool has_null_argument = (root_ir == nullptr) || (modified == nullptr);
  if (has_null_argument) {
    return Status(StatusCode::kUnexpectedError, "Null pointer passed to TreePass");
  }
  // Delegate the actual work to the RunOnTree() hook and return its status
  return RunOnTree(root_ir, modified);
}
// Driver method for NodePass
Status NodePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) {
@ -132,15 +111,23 @@ Status NodePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) {
// Helper function to perform DFS visit
Status NodePass::DFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified) {
RETURN_IF_NOT_OK(node_ir->Accept(this, modified));
bool m = false;
RETURN_IF_NOT_OK(node_ir->Accept(this, &m));
*modified |= m;
for (const auto &c : node_ir->Children()) {
RETURN_IF_NOT_OK(this->DFSNodeVisit(c, modified));
RETURN_IF_NOT_OK(this->DFSNodeVisit(c, &m));
*modified |= m;
}
return node_ir->AcceptAfter(this, modified);
RETURN_IF_NOT_OK(node_ir->AcceptAfter(this, &m));
*modified |= m;
return Status::OK();
}
// Helper function to perform BFS visit
Status NodePass::BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified) {
bool m = false;
// Initialize bfs queue with root
std::queue<std::shared_ptr<DatasetNode>> bfsQueue;
bfsQueue.push(node_ir);
@ -152,7 +139,8 @@ Status NodePass::BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modifi
bfsQueue.pop();
// Run node pass
RETURN_IF_NOT_OK(curNode->Accept(this, modified));
RETURN_IF_NOT_OK(curNode->Accept(this, &m));
*modified |= m;
// Push children into bfs queue
for (const auto &c : curNode->Children()) {
@ -162,331 +150,119 @@ Status NodePass::BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modifi
return Status::OK();
}
// For datasetops IR
// For non-leaf IR node
Status NodePass::Visit(std::shared_ptr<BatchNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<BatchNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::Visit(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#ifndef ENABLE_ANDROID
Status NodePass::Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#endif
Status NodePass::Visit(std::shared_ptr<BuildVocabNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::Visit(std::shared_ptr<ConcatNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<ConcatNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::Visit(std::shared_ptr<FilterNode> node, bool *modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<FilterNode> node, bool *modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::Visit(std::shared_ptr<MapNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<MapNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::Visit(std::shared_ptr<ProjectNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<ProjectNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::Visit(std::shared_ptr<RenameNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<RenameNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::Visit(std::shared_ptr<RepeatNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::Visit(std::shared_ptr<RootNode> node, bool *modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<RootNode> node, bool *modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::Visit(std::shared_ptr<ShuffleNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<ShuffleNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::Visit(std::shared_ptr<SkipNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<SkipNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::Visit(std::shared_ptr<TakeNode> node, bool *modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<TakeNode> node, bool *modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::Visit(std::shared_ptr<TransferNode> node, bool *modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<TransferNode> node, bool *modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::Visit(std::shared_ptr<ZipNode> node, bool *modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<ZipNode> node, bool *modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#ifdef ENABLE_PYTHON
Status NodePass::Visit(std::shared_ptr<SyncWaitNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<SyncWaitNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#endif
Status NodePass::Visit(std::shared_ptr<TakeNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<TakeNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::Visit(std::shared_ptr<TransferNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<TransferNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::Visit(std::shared_ptr<ZipNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<ZipNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
// For datasetops/source IR
Status NodePass::Visit(std::shared_ptr<AlbumNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<AlbumNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::Visit(std::shared_ptr<CelebANode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<CelebANode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::Visit(std::shared_ptr<Cifar100Node> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<Cifar100Node> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::Visit(std::shared_ptr<Cifar10Node> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<Cifar10Node> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#ifndef ENABLE_ANDROID
Status NodePass::Visit(std::shared_ptr<CLUENode> node, bool *modified) {
// Fallback to base class visitor by default
Status NodePass::Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<CLUENode> node, bool *modified) {
// Fallback to base class visitor by default
Status NodePass::VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#endif
Status NodePass::Visit(std::shared_ptr<CocoNode> node, bool *modified) {
// Fallback to base class visitor by default
// For leaf IR Node
Status NodePass::Visit(std::shared_ptr<SourceNode> node, bool *modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<CocoNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#ifndef ENABLE_ANDROID
Status NodePass::Visit(std::shared_ptr<CSVNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<CSVNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#endif
#ifdef ENABLE_PYTHON
Status NodePass::Visit(std::shared_ptr<GeneratorNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<GeneratorNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#endif
Status NodePass::Visit(std::shared_ptr<ImageFolderNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<ImageFolderNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::Visit(std::shared_ptr<ManifestNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<ManifestNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#ifndef ENABLE_ANDROID
Status NodePass::Visit(std::shared_ptr<MindDataNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<MindDataNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#endif
Status NodePass::Visit(std::shared_ptr<MnistNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<MnistNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::Visit(std::shared_ptr<RandomNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<RandomNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#ifndef ENABLE_ANDROID
Status NodePass::Visit(std::shared_ptr<TextFileNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<TextFileNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#endif
#ifndef ENABLE_ANDROID
Status NodePass::Visit(std::shared_ptr<TFRecordNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<TFRecordNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#endif
Status NodePass::Visit(std::shared_ptr<VOCNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<VOCNode> node, bool *modified) {
// Fallback to base class visitor by default
Status NodePass::VisitAfter(std::shared_ptr<SourceNode> node, bool *modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

View File

@ -26,123 +26,87 @@
namespace mindspore {
namespace dataset {
// Non-leaf IR node
class BatchNode;
class BucketBatchByLengthNode;
#ifndef ENABLE_ANDROID
class BuildSentenceVocabNode;
#endif
class BuildVocabNode;
class ConcatNode;
class FilterNode;
class MapNode;
class ProjectNode;
class RenameNode;
class RepeatNode;
class RootNode;
class ShuffleNode;
class SkipNode;
#ifdef ENABLE_PYTHON
class SyncWaitNode;
#endif
class TakeNode;
class TransferNode;
class ZipNode;
#ifdef ENABLE_PYTHON
class SyncWaitNode;
#endif
#ifndef ENABLE_ANDROID
class BuildSentenceVocabNode;
#endif
// Leaf IR node
class AlbumNode;
class CelebANode;
class Cifar100Node;
class Cifar10Node;
#ifndef ENABLE_ANDROID
class CLUENode;
#endif
class CocoNode;
#ifndef ENABLE_ANDROID
class CSVNode;
#endif
class ImageFolderNode;
class ManifestNode;
class MnistNode;
class RandomNode;
class VOCNode;
#ifdef ENABLE_PYTHON
class GeneratorNode;
#endif
class ImageFolderNode;
class ManifestNode;
#ifndef ENABLE_ANDROID
class CLUENode;
class CSVNode;
class MindDataNode;
#endif
class MnistNode;
class RandomNode;
#ifndef ENABLE_ANDROID
class TextFileNode;
#endif
#ifndef ENABLE_ANDROID
class TFRecordNode;
#endif
class VOCNode;
//////////////////////////////////
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
class BatchOp;
class MapOp;
class ProjectOp;
class RenameOp;
class SkipOp;
class ShuffleOp;
class AlbumOp;
class RandomDataOp;
class RepeatOp;
class TakeOp;
class ZipOp;
class DeviceQueueOp;
class ImageFolderOp;
class MnistOp;
class ManifestOp;
class CifarOp;
class VOCOp;
class CocoOp;
class CelebAOp;
class EpochCtrlOp;
class BuildVocabOp;
class ConcatOp;
#ifndef ENABLE_ANDROID
class MindRecordOp;
class TFReaderOp;
class CacheOp;
class CacheMergeOp;
class CacheLookupOp;
class BuildSentencePieceVocabOp;
class ClueOp;
class CsvOp;
class TextFileOp;
#endif
#ifdef ENABLE_PYTHON
class FilterOp;
class GeneratorOp;
#endif
//////////////////////////////////
@ -175,6 +139,13 @@ class TreePass : public Pass {
/// \param[inout] modified Indicate if the tree was modified
Status Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) final;
/// \brief Derived classes may override the RunOnTree function to implement a tree transformation.
/// The "modified" flag needs to be set to true if the tree is modified during the pass execution.
/// The default implementation defined here does nothing and returns Status::OK().
/// \param[inout] root_ir The root of the IR tree to operate on.
/// \param[inout] modified Indicates whether the tree was modified.
/// \return Status The error code returned
virtual Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) { return Status::OK(); }
//////////////////////////////////
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
/// \brief Run the transformation pass against the execution tree.
@ -191,8 +162,17 @@ class TreePass : public Pass {
//////////////////////////////////
};
// NodePass is a basic Pass class which performs transformation on Node visiting.
// NodePass is a base Pass class which performs transformation on node visiting.
// NodePass implements Visitor design pattern.
// The visiting happens twice for each node in the DFS traversal, one on the way down of the traversal,
// and the other when all the descending nodes are visited.
// Actual transformation is done by implementing a new derived class of NodePass.
// The derived class implements the methods Visit()/VisitAfter() for the specific node types
// it wants to act on, overriding the ones defined in NodePass.
// If the derived class wants to perform the same action on all node types,
// it can simply implement the method Visit()/VisitAfter() passing the base class DatasetNode.
// This is made possible by overloading the method Visit()/VisitAfter() on each node type to fall back
// to call the Visit()/VisitAfter() in this parent NodePass class.
class NodePass : public Pass {
public:
// Tree traversal order
@ -223,153 +203,57 @@ class NodePass : public Pass {
/// \return Status The error code return
virtual Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) { return Status::OK(); }
// For datasetops IR
// Visit method to be overridden.
// Note that member template can not be virtual, any node which wants to work with NodePass
// should declare Visit of its own type and override "Accept" from DatasetNode.
// Visit()/VisitAfter() method to be overridden.
// These pairs of Visit()/VisitAfter() for each derived class of DatasetNode are defined here.
// Their implementations are in the .cc file to avoid adding the include files of those derived classes.
// Each implementation simply falls back to calling Visit()/VisitAfter() of class DatasetNode, the parent of
// the derived classes. With this technique, a transformation class derived from NodePass only needs to
// implement Visit()/VisitAfter() taking a DatasetNode if it wants to act on all derived classes
// of DatasetNode in the same way.
// Note that virtual template functions are not permitted in C++.
//
// Non-leaf IR node
virtual Status Visit(std::shared_ptr<BatchNode> node, bool *modified);
// VisitAfter method to be overridden.
// Note that member template can not be virtual, any node which wants to work with NodePass
// should declare VisitAfter of its own type and override "AcceptAfter" from DatasetNode.
virtual Status VisitAfter(std::shared_ptr<BatchNode> node, bool *modified);
virtual Status Visit(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified);
#ifndef ENABLE_ANDROID
virtual Status Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified);
#endif
virtual Status Visit(std::shared_ptr<BuildVocabNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *modified);
virtual Status Visit(std::shared_ptr<ConcatNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<ConcatNode> node, bool *modified);
virtual Status Visit(std::shared_ptr<FilterNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<FilterNode> node, bool *modified);
virtual Status Visit(std::shared_ptr<MapNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<MapNode> node, bool *modified);
virtual Status Visit(std::shared_ptr<ProjectNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<ProjectNode> node, bool *modified);
virtual Status Visit(std::shared_ptr<RenameNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<RenameNode> node, bool *modified);
virtual Status Visit(std::shared_ptr<RepeatNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified);
virtual Status Visit(std::shared_ptr<RootNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<RootNode> node, bool *modified);
virtual Status Visit(std::shared_ptr<ShuffleNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<ShuffleNode> node, bool *modified);
virtual Status Visit(std::shared_ptr<SkipNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<SkipNode> node, bool *modified);
virtual Status Visit(std::shared_ptr<TakeNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<TakeNode> node, bool *modified);
virtual Status Visit(std::shared_ptr<TransferNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<TransferNode> node, bool *modified);
virtual Status Visit(std::shared_ptr<ZipNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<ZipNode> node, bool *modified);
#ifdef ENABLE_PYTHON
virtual Status Visit(std::shared_ptr<SyncWaitNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<SyncWaitNode> node, bool *modified);
#endif
virtual Status Visit(std::shared_ptr<TakeNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<TakeNode> node, bool *modified);
virtual Status Visit(std::shared_ptr<TransferNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<TransferNode> node, bool *modified);
virtual Status Visit(std::shared_ptr<ZipNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<ZipNode> node, bool *modified);
// For datasetops/source IR
virtual Status Visit(std::shared_ptr<AlbumNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<AlbumNode> node, bool *modified);
virtual Status Visit(std::shared_ptr<CelebANode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<CelebANode> node, bool *modified);
virtual Status Visit(std::shared_ptr<Cifar100Node> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<Cifar100Node> node, bool *modified);
virtual Status Visit(std::shared_ptr<Cifar10Node> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<Cifar10Node> node, bool *modified);
#ifndef ENABLE_ANDROID
virtual Status Visit(std::shared_ptr<CLUENode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<CLUENode> node, bool *modified);
virtual Status Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified);
#endif
virtual Status Visit(std::shared_ptr<CocoNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<CocoNode> node, bool *modified);
#ifndef ENABLE_ANDROID
virtual Status Visit(std::shared_ptr<CSVNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<CSVNode> node, bool *modified);
#endif
#ifdef ENABLE_PYTHON
virtual Status Visit(std::shared_ptr<GeneratorNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<GeneratorNode> node, bool *modified);
#endif
virtual Status Visit(std::shared_ptr<ImageFolderNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<ImageFolderNode> node, bool *modified);
virtual Status Visit(std::shared_ptr<ManifestNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<ManifestNode> node, bool *modified);
#ifndef ENABLE_ANDROID
virtual Status Visit(std::shared_ptr<MindDataNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<MindDataNode> node, bool *modified);
#endif
virtual Status Visit(std::shared_ptr<MnistNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<MnistNode> node, bool *modified);
virtual Status Visit(std::shared_ptr<RandomNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<RandomNode> node, bool *modified);
#ifndef ENABLE_ANDROID
virtual Status Visit(std::shared_ptr<TextFileNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<TextFileNode> node, bool *modified);
#endif
#ifndef ENABLE_ANDROID
virtual Status Visit(std::shared_ptr<TFRecordNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<TFRecordNode> node, bool *modified);
#endif
virtual Status Visit(std::shared_ptr<VOCNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<VOCNode> node, bool *modified);
// Leaf IR node
virtual Status Visit(std::shared_ptr<SourceNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<SourceNode> node, bool *modified);
//////////////////////////////////
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
@ -396,86 +280,47 @@ class NodePass : public Pass {
// Note that member template can not be virtual, any op which wants to work with NodePass should declare RunOnNode
// of its own type and override "Accept" from DatasetOp.
virtual Status RunOnNode(std::shared_ptr<BatchOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<RenameOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<SkipOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<ZipOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<ZipOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<MapOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified);
#ifndef ENABLE_ANDROID
virtual Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<ClueOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<CsvOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified);
#endif
#ifdef ENABLE_PYTHON
virtual Status RunOnNode(std::shared_ptr<FilterOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified);
#endif
//////////////////////////////////

View File

@ -18,6 +18,7 @@
#include "minddata/dataset/core/client.h"
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/root_node.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/engine/opt/pre/input_validation_pass.h"
@ -119,11 +120,16 @@ Status TreeAdapter::BuildExecutionTree(std::shared_ptr<DatasetNode> ir, std::sha
return Status::OK();
}
Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> root_ir, int32_t num_epochs) {
num_epochs_ = num_epochs;
Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> input_ir, int32_t num_epochs) {
optimize_ = true; // Always ON (temporary)
RETURN_UNEXPECTED_IF_NULL(root_ir);
RETURN_UNEXPECTED_IF_NULL(input_ir);
MS_LOG(INFO) << "Input plan:" << '\n' << *input_ir << '\n';
// Copy the input IR tree and insert under the root node
// Create a root node to host the input IR tree
auto root_ir = std::make_shared<RootNode>(input_ir->DeepCopy(), num_epochs);
MS_LOG(INFO) << "Plan before PrePass:" << '\n' << *root_ir << '\n';
// Pre-pass of the IR tree
RETURN_IF_NOT_OK(PrePass(root_ir));
@ -136,11 +142,15 @@ Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> root_ir, int32_t num_ep
// Post-pass of the IR tree
RETURN_IF_NOT_OK(PostPass(root_ir));
MS_LOG(INFO) << "Plan after PostPass:" << '\n' << *root_ir << '\n';
// This will evolve in the long run
tree_ = std::make_unique<ExecutionTree>();
// Build the Execution tree from the child of the root node
std::shared_ptr<DatasetOp> root_op;
RETURN_IF_NOT_OK(BuildExecutionTree(root_ir, &root_op));
// We will replace input_ir with root_ir->Children()[0] once IR optimizer is in
RETURN_IF_NOT_OK(BuildExecutionTree(input_ir, &root_op));
RETURN_IF_NOT_OK(tree_->AssignRoot(root_op));
if (pre_pass_override_) tree_->SetPrePassOverride(pre_pass_override_);

View File

@ -67,10 +67,6 @@ class TreeAdapter {
// Optional optimizations status
bool OptimizationEnabled() const { return optimize_; }
// Getter function to get the total number of epochs to be run on this tree.
// @return total number of epochs
int32_t num_epochs() { return num_epochs_; }
private:
// This function runs a mandatory pass checking the syntax and semantics of the IR tree.
Status PrePass(std::shared_ptr<DatasetNode> ir);

View File

@ -47,6 +47,10 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> {
/// \return Shared pointers to the newly created Sampler
virtual std::shared_ptr<SamplerRT> Build() = 0;
/// \brief Pure virtual function to copy a SamplerObj object
/// \return Shared pointer to the newly copied SamplerObj
virtual std::shared_ptr<SamplerObj> Copy() = 0;
/// \brief Function for derived class to get the shard id of sampler
/// \return The shard id of the derived sampler
virtual int64_t ShardId() { return 0; }
@ -132,6 +136,11 @@ class DistributedSamplerObj : public SamplerObj {
std::shared_ptr<SamplerRT> Build() override;
/// \brief Create a new DistributedSamplerObj constructed from this object's stored parameters.
/// \return Shared pointer to the newly created sampler object
std::shared_ptr<SamplerObj> Copy() override {
  std::shared_ptr<SamplerObj> duplicate = std::make_shared<DistributedSamplerObj>(
    num_shards_, shard_id_, shuffle_, num_samples_, seed_, offset_, even_dist_);
  return duplicate;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
@ -160,6 +169,10 @@ class PKSamplerObj : public SamplerObj {
std::shared_ptr<SamplerRT> Build() override;
/// \brief Create a new PKSamplerObj constructed from this object's stored parameters.
/// \return Shared pointer to the newly created sampler object
std::shared_ptr<SamplerObj> Copy() override {
  std::shared_ptr<SamplerObj> duplicate = std::make_shared<PKSamplerObj>(num_val_, shuffle_, num_samples_);
  return duplicate;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
@ -174,9 +187,8 @@ class PKSamplerObj : public SamplerObj {
class PreBuiltSamplerObj : public SamplerObj {
public:
#ifndef ENABLE_ANDROID
explicit PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler);
#ifndef ENABLE_ANDROID
explicit PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler);
#endif
@ -188,6 +200,8 @@ class PreBuiltSamplerObj : public SamplerObj {
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
std::shared_ptr<SamplerObj> Copy() override;
bool ValidateParams() override;
private:
@ -205,6 +219,8 @@ class RandomSamplerObj : public SamplerObj {
std::shared_ptr<SamplerRT> Build() override;
/// \brief Create a new RandomSamplerObj constructed from this object's stored parameters.
/// \return Shared pointer to the newly created sampler object
std::shared_ptr<SamplerObj> Copy() override {
  std::shared_ptr<SamplerObj> duplicate = std::make_shared<RandomSamplerObj>(replacement_, num_samples_);
  return duplicate;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
@ -224,6 +240,10 @@ class SequentialSamplerObj : public SamplerObj {
std::shared_ptr<SamplerRT> Build() override;
/// \brief Create a new SequentialSamplerObj constructed from this object's stored parameters.
/// \return Shared pointer to the newly created sampler object
std::shared_ptr<SamplerObj> Copy() override {
  std::shared_ptr<SamplerObj> duplicate = std::make_shared<SequentialSamplerObj>(start_index_, num_samples_);
  return duplicate;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
@ -243,6 +263,10 @@ class SubsetRandomSamplerObj : public SamplerObj {
std::shared_ptr<SamplerRT> Build() override;
/// \brief Create a new SubsetRandomSamplerObj constructed from this object's stored parameters.
/// \return Shared pointer to the newly created sampler object
std::shared_ptr<SamplerObj> Copy() override {
  std::shared_ptr<SamplerObj> duplicate = std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_);
  return duplicate;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
@ -262,6 +286,10 @@ class WeightedRandomSamplerObj : public SamplerObj {
std::shared_ptr<SamplerRT> Build() override;
/// \brief Create a new WeightedRandomSamplerObj constructed from this object's stored parameters.
/// \return Shared pointer to the newly created sampler object
std::shared_ptr<SamplerObj> Copy() override {
  std::shared_ptr<SamplerObj> duplicate =
    std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_);
  return duplicate;
}
bool ValidateParams() override;
private:

View File

@ -32,7 +32,10 @@ class TensorOp;
class TensorOperation : public std::enable_shared_from_this<TensorOperation> {
public:
/// \brief Constructor
TensorOperation();
TensorOperation() : random_op_(false) {}
/// \brief Constructor
explicit TensorOperation(bool random) : random_op_(random) {}
/// \brief Destructor
~TensorOperation() = default;
@ -42,6 +45,13 @@ class TensorOperation : public std::enable_shared_from_this<TensorOperation> {
virtual std::shared_ptr<TensorOp> Build() = 0;
virtual Status ValidateParams() = 0;
/// \brief Check whether the operation is a random (non-deterministic) operation.
/// \return true if this op is a random op (returns non-deterministic result e.g. RandomCrop)
bool IsRandomOp() const { return random_op_; }
protected:
bool random_op_;
};
// Helper function to validate fill value

View File

@ -427,7 +427,7 @@ Status PadEndNumeric(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor>
Status PadEndNumericHelper(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> dst,
std::vector<dsize_t> cur_ind, size_t cur_dim) {
if (cur_dim == src->Rank() - 1) { // if this is the last dimension, copy the data
dst->CopyLastDimAt(src, cur_ind);
RETURN_IF_NOT_OK(dst->CopyLastDimAt(src, cur_ind));
} else { // not the last dimension, keep doing recursion
dsize_t min_ind = std::min(dst->shape()[cur_dim], src->shape()[cur_dim]);
for (dsize_t i = 0; i < min_ind; i++) {

View File

@ -57,7 +57,7 @@ class RandomCropOp : public TensorOp {
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
// Function breaks out the compute function's image padding functionality and makes available to other Ops
// Using this class as a base - restructrued to allow for RandomCropWithBBox Augmentation Op
// Using this class as a base - re-structured to allow for RandomCropWithBBox Augmentation Op
// @param input: Input is the original Image
// @param pad_image: Pointer to new Padded image
// @param t_pad_top: Total Top Padding - Based on input and value calculated in function if required

View File

@ -570,7 +570,7 @@ class WeightedRandomSampler(BuiltinSampler):
Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities).
Args:
weights (list[float]): A sequence of weights, not necessarily summing up to 1.
weights (list[float, int]): A sequence of weights, not necessarily summing up to 1.
num_samples (int, optional): Number of elements to sample (default=None, all elements).
replacement (bool): If True, put the sample ID back for the next draw (default=True).

View File

@ -17,6 +17,7 @@ SET(DE_UT_SRCS
c_api_dataset_coco_test.cc
c_api_dataset_config_test.cc
c_api_dataset_csv_test.cc
c_api_dataset_ir_node_test.cc
c_api_dataset_iterator_test.cc
c_api_dataset_manifest_test.cc
c_api_dataset_minddata_test.cc

View File

@ -0,0 +1,142 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include <string>
#include "minddata/dataset/core/client.h"
#include "common/common.h"
#include "gtest/gtest.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
#include "minddata/dataset/engine/opt/pre/getter_pass.h"
using namespace mindspore::dataset;
using mindspore::LogStream;
using mindspore::MsLogLevel::INFO;
/// \brief Test fixture for IR node tests; provides helpers to compare and to print IR trees.
class MindDataTestIRNodes : public UT::DatasetOpTesting {
 public:
  MindDataTestIRNodes() = default;

  // Runs GlobalInit() before every test case.
  void SetUp() override { GlobalInit(); }

  /// \brief Compare the node pointers of two trees, position by position. Used to test the deep copy of nodes.
  /// \param[in] root1 Root of the first tree
  /// \param[in] root2 Root of the second tree
  /// \param[in] flag Expected result of (ptr1 == ptr2) for every pair of corresponding nodes
  /// \return Status error if either root is null, if (ptr1 == ptr2) does not equal flag for some pair of nodes,
  ///     if the node names differ, or if the two trees have a different number of children at some node
  Status CompareTwoTrees(std::shared_ptr<DatasetNode> root1, std::shared_ptr<DatasetNode> root2, bool flag) {
    CHECK_FAIL_RETURN_UNEXPECTED(root1 != nullptr && root2 != nullptr, "Error in Compare, nullptr.");
    if (((root1.get() == root2.get()) != flag) || (root1->Name() != root2->Name())) {
      std::string err_msg =
        "Expect node ptr " + root1->Name() + (flag ? "==" : "!=") + root2->Name() + " but they aren't!";
      RETURN_STATUS_UNEXPECTED(err_msg);
    }

    size_t num_child = root1->Children().size();
    CHECK_FAIL_RETURN_UNEXPECTED(num_child == root2->Children().size(),
                                 root1->Name() + " has " + std::to_string(num_child) + " child, node #2 has " +
                                   std::to_string(root2->Children().size()) + " child.");

    // Recurse into each pair of children at the same position.
    for (size_t ind = 0; ind < num_child; ind++) {
      RETURN_IF_NOT_OK(CompareTwoTrees(root1->Children()[ind], root2->Children()[ind], flag));
    }
    return Status::OK();
  }

  /// \brief Append the names of the nodes of a tree, in post order, to a string. Each name is followed by "->".
  /// \param[in] ir Root of the tree to traverse
  /// \param[inout] names String that the node names are appended to
  /// \return Status error if a null node is encountered
  Status PostOrderPrintTree(std::shared_ptr<DatasetNode> ir, std::string &names) {
    RETURN_UNEXPECTED_IF_NULL(ir);
    // Children are visited before the node itself; bind by reference to avoid copying each shared_ptr.
    for (const auto &child : ir->Children()) {
      RETURN_IF_NOT_OK(PostOrderPrintTree(child, names));
    }
    names += (ir->Name() + "->");
    return Status::OK();
  }
};
// Deep-copy a single-branch pipeline and check that the copy has the same node names as the
// original while none of its nodes share a pointer with the original.
TEST_F(MindDataTestIRNodes, MindDataTestSimpleDeepCopy) {
  MS_LOG(INFO) << "Doing MindDataTestIRNodes-MindDataTestSimpleDeepCopy.";

  auto original = RandomData(44)->Repeat(2)->Project({"label"})->Shuffle(10)->Batch(2)->IRNode();
  auto duplicate = original->DeepCopy();

  // Both post-order traversals are expected to produce:
  // RandomDataset->Repeat->Project->Shuffle->Batch->
  std::string original_names;
  std::string duplicate_names;
  ASSERT_OK(PostOrderPrintTree(original, original_names));
  ASSERT_OK(PostOrderPrintTree(duplicate, duplicate_names));
  EXPECT_EQ(original_names, duplicate_names);

  // A tree compared with itself shares every pointer; compared with its deep copy it shares none.
  ASSERT_OK(CompareTwoTrees(original, original, true));
  ASSERT_OK(CompareTwoTrees(original, duplicate, false));

  // Sanity check of the compare helper: a tree never differs from itself pointer-wise.
  EXPECT_TRUE(CompareTwoTrees(duplicate, duplicate, false).IsError());
}
// Deep-copy a pipeline with two branches joined by Zip and check that the copy has the same node
// names as the original while none of its nodes share a pointer with the original.
TEST_F(MindDataTestIRNodes, MindDataTestZipDeepCopy) {
  MS_LOG(INFO) << "Doing MindDataTestIRNodes-MindDataTestZipDeepCopy.";

  auto branch1 = RandomData(44)->Project({"label"});
  auto branch2 = RandomData(44)->Shuffle(10);
  auto original = Zip({branch1, branch2})->Batch(2)->IRNode();
  auto duplicate = original->DeepCopy();

  // Both post-order traversals are expected to produce:
  // RandomDataset->Project->RandomDataset->Shuffle->Zip->Batch->
  std::string original_names;
  std::string duplicate_names;
  ASSERT_OK(PostOrderPrintTree(original, original_names));
  ASSERT_OK(PostOrderPrintTree(duplicate, duplicate_names));
  EXPECT_EQ(original_names, duplicate_names);

  // verify the pointer within the same tree are the same
  ASSERT_OK(CompareTwoTrees(original, original, true));
  // verify two trees
  ASSERT_OK(CompareTwoTrees(original, duplicate, false));
}
// Remove a non-leaf node (Project) and a leaf node (ImageFolder) from deep copies of the same tree
// and check the resulting tree shapes through their post-order node names.
TEST_F(MindDataTestIRNodes, MindDataTestNodeRemove) {
  MS_LOG(INFO) << "Doing MindDataTestIRNodes-MindDataTestNodeRemove.";
  auto branch1 = RandomData(44)->Project({"label"});
  auto branch2 = ImageFolder("path");
  auto tree = Zip({branch1, branch2})->IRNode();
  /***
   tree looks like this, we will remove node and test its functionalities
              Zip
             /   \
       Project   ImageFolder
          /
    RandomData
  ***/

  // Case 1: remove the Project node (a non-leaf node) from a copy of the tree
  auto tree_copy_1 = tree->DeepCopy();
  ASSERT_EQ(tree_copy_1->Children().size(), 2);
  ASSERT_OK(tree_copy_1->Children()[0]->Remove());  // remove Project from tree
  // The result must match a freshly built Zip(RandomData, ImageFolder) tree in names and structure,
  // without sharing any node pointer with it (flag == false).
  ASSERT_OK(CompareTwoTrees(tree_copy_1, Zip({RandomData(44), ImageFolder("path")})->IRNode(), false));
  std::string tree_1_names, tree_2_names;
  ASSERT_OK(PostOrderPrintTree(tree_copy_1, tree_1_names));
  EXPECT_EQ(tree_1_names, "RandomDataset->ImageFolderDataset->Zip->");

  // Case 2: remove the ImageFolder, a leaf node, from another copy of the tree
  auto tree_copy_2 = tree->DeepCopy();
  ASSERT_EQ(tree_copy_2->Children().size(), 2);
  // Check the returned Status so that a failed removal is reported here, not by the name comparison below.
  ASSERT_OK(tree_copy_2->Children()[1]->Remove());
  ASSERT_OK(PostOrderPrintTree(tree_copy_2, tree_2_names));
  EXPECT_EQ(tree_2_names, "RandomDataset->Project->Zip->");
}