forked from mindspore-Ecosystem/mindspore
[MD] skip node pushdown optimization pass for Reset
This commit is contained in:
parent
2655d64720
commit
260cebf650
|
@ -7,6 +7,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SRC_FILES
|
|||
random_sampler.cc
|
||||
sampler.cc
|
||||
sequential_sampler.cc
|
||||
skip_first_epoch_sampler.cc
|
||||
subset_random_sampler.cc
|
||||
subset_sampler.cc
|
||||
weighted_random_sampler.cc
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -66,7 +66,7 @@ class SequentialSamplerRT : public SamplerRT {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
private:
|
||||
protected:
|
||||
int64_t current_id_; // The id sequencer. Each new id increments from this
|
||||
int64_t start_index_; // The starting id. current_id_ begins from here.
|
||||
int64_t id_count_; // An internal counter that tracks how many ids have been produced
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/skip_first_epoch_sampler.h"
|
||||
#include <string>
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// Prepare the sampler for the next epoch.
// Fails if the number of ids produced so far (id_count_) differs from num_samples_. On the first successful call the
// start offset is folded back into num_samples_ and cleared, so later epochs begin from id 0.
Status SkipFirstEpochSamplerRT::ResetSampler() {
  const bool epoch_fully_consumed = (id_count_ == num_samples_);
  if (!epoch_fully_consumed) {
    std::string err_msg = "[Internal ERROR] ResetSampler() called early or late. id_count_: ";
    err_msg += std::to_string(id_count_);
    err_msg += " num_samples_: ";
    err_msg += std::to_string(num_samples_);
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_UNEXPECTED(err_msg);
  }

  // Restart the id sequence and the produced-id counter.
  id_count_ = 0;
  current_id_ = 0;

  // One-time transition out of the first epoch: give back the skipped samples and drop the offset.
  if (first_epoch_done_ == false) {
    num_samples_ += start_index_;
    start_index_ = 0;
    samples_per_tensor_ = num_samples_;
    first_epoch_done_ = true;
  }

  // Without a child sampler there is nothing further to reset.
  if (!HasChildSampler()) {
    return Status::OK();
  }
  RETURN_IF_NOT_OK(child_[0]->ResetSampler());
  return Status::OK();
}
|
||||
|
||||
// Always reports -1, whatever num_rows is: this sampler yields a different number of samples in the first epoch than
// in later ones.
int64_t SkipFirstEpochSamplerRT::CalculateNumSamples(const int64_t num_rows) {
  return -1;
}
|
||||
|
||||
void SkipFirstEpochSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const {
|
||||
out << "\nSampler: SkipFirstEpochSampler";
|
||||
if (show_all) {
|
||||
// Call the super class for displaying any common detailed info
|
||||
SamplerRT::SamplerPrint(out, show_all);
|
||||
// Then add our own info
|
||||
out << "\nStart index: " << start_index_;
|
||||
out << "\nFirst epoch done: " << first_epoch_done_;
|
||||
out << "\nCurrent id: " << current_id_;
|
||||
out << "\nid count:" << id_count_;
|
||||
}
|
||||
}
|
||||
|
||||
// Serialize this sampler: the common SamplerRT fields plus "sampler_name" and "start_index".
// Returns an error if out_json is null or if the SamplerRT serialization fails.
Status SkipFirstEpochSamplerRT::to_json(nlohmann::json *out_json) {
  RETURN_UNEXPECTED_IF_NULL(out_json);

  // Start from what SamplerRT writes, then add the fields specific to this sampler.
  nlohmann::json sampler_args;
  RETURN_IF_NOT_OK(SamplerRT::to_json(&sampler_args));
  sampler_args["sampler_name"] = "SkipFirstEpochSampler";
  sampler_args["start_index"] = start_index_;

  *out_json = sampler_args;
  return Status::OK();
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,60 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SKIP_FIRST_EPOCH_SAMPLER_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SKIP_FIRST_EPOCH_SAMPLER_H_
|
||||
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class SkipFirstEpochSamplerRT : public SequentialSamplerRT {
|
||||
public:
|
||||
// Constructor
|
||||
using SequentialSamplerRT::SequentialSamplerRT;
|
||||
|
||||
// Destructor.
|
||||
~SkipFirstEpochSamplerRT() = default;
|
||||
|
||||
// for next epoch of sampleIds
|
||||
// @return Status The status code returned
|
||||
Status ResetSampler() override;
|
||||
|
||||
/// \brief Gets the number of samples available
|
||||
/// \note Since this sampler returns different number of samples in the first epoch (compared to other epochs), this
|
||||
/// function always returns -1
|
||||
/// \param[in] num_rows The total number of rows in the dataset
|
||||
/// \return int64_t Calculated number of samples (always -1)
|
||||
int64_t CalculateNumSamples(int64_t num_rows) override;
|
||||
|
||||
// Printer for debugging purposes.
|
||||
// @param out - output stream to write to
|
||||
// @param show_all - bool to show detailed vs summary
|
||||
void SamplerPrint(std::ostream &out, bool show_all) const override;
|
||||
|
||||
/// \brief Get the arguments of node
|
||||
/// \param[out] out_json JSON string of all attributes
|
||||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
private:
|
||||
bool first_epoch_done_ = false;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SKIP_FIRST_EPOCH_SAMPLER_H_
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -21,8 +21,9 @@
|
|||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/project_op.h"
|
||||
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
|
@ -73,5 +74,18 @@ Status ProjectNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNo
|
|||
*result = std::make_shared<ProjectNode>(ds, columns);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Visitor accepting method for IRNodePass
|
||||
Status ProjectNode::Accept(IRNodePass *const p, bool *const modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->Visit(shared_from_base<ProjectNode>(), modified);
|
||||
}
|
||||
|
||||
// Visitor accepting method for IRNodePass
|
||||
Status ProjectNode::AcceptAfter(IRNodePass *const p, bool *const modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->VisitAfter(shared_from_base<ProjectNode>(), modified);
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -70,6 +70,18 @@ class ProjectNode : public DatasetNode {
|
|||
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
|
||||
std::shared_ptr<DatasetNode> *result);
|
||||
|
||||
/// \brief Base-class override for accepting IRNodePass visitor
|
||||
/// \param[in] p The pass that visits this node
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status Accept(IRNodePass *const p, bool *const modified) override;
|
||||
|
||||
/// \brief Base-class override for accepting IRNodePass visitor
|
||||
/// \param[in] p The pass that visits this node
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status AcceptAfter(IRNodePass *const p, bool *const modified) override;
|
||||
|
||||
private:
|
||||
std::vector<std::string> columns_;
|
||||
};
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -21,8 +21,9 @@
|
|||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/rename_op.h"
|
||||
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
|
@ -82,5 +83,17 @@ Status RenameNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNod
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Visitor accepting method for IRNodePass
|
||||
Status RenameNode::Accept(IRNodePass *const p, bool *const modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->Visit(shared_from_base<RenameNode>(), modified);
|
||||
}
|
||||
|
||||
// Visitor accepting method for IRNodePass
|
||||
Status RenameNode::AcceptAfter(IRNodePass *const p, bool *const modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->VisitAfter(shared_from_base<RenameNode>(), modified);
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -38,6 +38,18 @@ class RenameNode : public DatasetNode {
|
|||
/// \return Name of the current node
|
||||
std::string Name() const override { return kRenameNode; }
|
||||
|
||||
/// \brief Base-class override for accepting IRNodePass visitor
|
||||
/// \param[in] p The pass that visits this node
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status Accept(IRNodePass *const p, bool *const modified) override;
|
||||
|
||||
/// \brief Base-class override for accepting IRNodePass visitor
|
||||
/// \param[in] p The pass that visits this node
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status AcceptAfter(IRNodePass *const p, bool *const modified) override;
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -99,6 +99,9 @@ class SkipNode : public DatasetNode {
|
|||
|
||||
/// \brief Setter for the first-epoch-only flag
/// \param[in] flag Value stored in first_epoch_only_
void SetFirstEpochOnly(bool flag) {
  first_epoch_only_ = flag;
}
|
||||
|
||||
/// \brief Getter functions
|
||||
const bool FirstEpochOnly() const { return first_epoch_only_; }
|
||||
|
||||
private:
|
||||
int32_t skip_count_;
|
||||
bool first_epoch_only_ = false;
|
||||
|
|
|
@ -8,6 +8,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SRC_FILES
|
|||
random_sampler_ir.cc
|
||||
samplers_ir.cc
|
||||
sequential_sampler_ir.cc
|
||||
skip_first_epoch_sampler_ir.cc
|
||||
subset_random_sampler_ir.cc
|
||||
subset_sampler_ir.cc
|
||||
weighted_random_sampler_ir.cc
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -65,7 +65,7 @@ class SequentialSamplerObj : public SamplerObj {
|
|||
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
protected:
|
||||
int64_t start_index_;
|
||||
int64_t num_samples_;
|
||||
};
|
||||
|
|
|
@ -0,0 +1,68 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/samplers/skip_first_epoch_sampler_ir.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/skip_first_epoch_sampler.h"
|
||||
#include "minddata/dataset/util/validators.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// Constructor
|
||||
SkipFirstEpochSamplerObj::SkipFirstEpochSamplerObj(int64_t start_index) : SequentialSamplerObj(start_index, 0) {}
|
||||
|
||||
// Destructor
|
||||
// Defaulted here, out of line, to match the declaration-only destructor in the header.
SkipFirstEpochSamplerObj::~SkipFirstEpochSamplerObj() = default;
|
||||
|
||||
// Serialize this IR sampler: the common SamplerObj fields plus "sampler_name" and "start_index".
// Returns an error if out_json is null or if the SamplerObj serialization fails.
Status SkipFirstEpochSamplerObj::to_json(nlohmann::json *const out_json) {
  // Guard the output pointer before it is dereferenced below.
  RETURN_UNEXPECTED_IF_NULL(out_json);
  nlohmann::json args;
  RETURN_IF_NOT_OK(SamplerObj::to_json(&args));
  args["sampler_name"] = "SkipFirstEpochSamplerObj";
  args["start_index"] = start_index_;
  *out_json = args;
  return Status::OK();
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
Status SkipFirstEpochSamplerObj::from_json(nlohmann::json json_obj, std::shared_ptr<SamplerObj> *sampler) {
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "start_index", "SkipFirstEpochSamplerObj"));
|
||||
int64_t start_index = json_obj["start_index"];
|
||||
*sampler = std::make_shared<SkipFirstEpochSamplerObj>(start_index);
|
||||
// Run common code in super class to add children samplers
|
||||
RETURN_IF_NOT_OK(SamplerObj::from_json(json_obj, sampler));
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
// Build the runtime sampler (SkipFirstEpochSamplerRT) for this IR object and attach its children.
// On success *sampler holds the new runtime sampler. If building the children fails, *sampler is reset to nullptr and
// the failing Status is returned. Returns an error if sampler itself is null.
Status SkipFirstEpochSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) {
  RETURN_UNEXPECTED_IF_NULL(sampler);
  // runtime sampler object
  *sampler = std::make_shared<dataset::SkipFirstEpochSamplerRT>(start_index_, 0);
  Status s = BuildChildren(sampler);
  if (s.IsError()) {
    // Clear the caller's pointer (not the local parameter) so a half-built sampler is never handed back.
    *sampler = nullptr;
  }
  return s;
}
|
||||
|
||||
std::shared_ptr<SamplerObj> SkipFirstEpochSamplerObj::SamplerCopy() {
|
||||
auto sampler = std::make_shared<SkipFirstEpochSamplerObj>(start_index_);
|
||||
for (const auto &child : children_) {
|
||||
Status rc = sampler->AddChildSampler(child);
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "[Internal ERROR] Error in copying the sampler. Message: " << rc;
|
||||
}
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,56 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SKIP_FIRST_EPOCH_SAMPLER_IR_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SKIP_FIRST_EPOCH_SAMPLER_IR_H_
|
||||
|
||||
#include <memory>
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/samplers/sequential_sampler_ir.h"
|
||||
#include "include/api/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// Internal Sampler class forward declaration
|
||||
class SamplerRT;
|
||||
|
||||
class SkipFirstEpochSamplerObj : public SequentialSamplerObj {
|
||||
public:
|
||||
explicit SkipFirstEpochSamplerObj(int64_t start_index);
|
||||
|
||||
~SkipFirstEpochSamplerObj() override;
|
||||
|
||||
Status SamplerBuild(std::shared_ptr<SamplerRT> *sampler) override;
|
||||
|
||||
std::shared_ptr<SamplerObj> SamplerCopy() override;
|
||||
|
||||
/// \brief Get the arguments of node
|
||||
/// \param[out] out_json JSON string of all attributes
|
||||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *const out_json) override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Function for read sampler from JSON object
|
||||
/// \param[in] json_obj JSON object to be read
|
||||
/// \param[out] sampler Sampler constructed from parameters in JSON object
|
||||
/// \return Status of the function
|
||||
static Status from_json(nlohmann::json json_obj, std::shared_ptr<SamplerObj> *sampler);
|
||||
#endif
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SKIP_FIRST_EPOCH_SAMPLER_IR_H_
|
|
@ -15,6 +15,7 @@ set(DATASET_ENGINE_OPT_SRC_FILES
|
|||
pre/input_validation_pass.cc
|
||||
pre/node_offload_pass.cc
|
||||
pre/node_removal_pass.cc
|
||||
pre/skip_pushdown_pass.cc
|
||||
)
|
||||
|
||||
if(ENABLE_PYTHON)
|
||||
|
|
|
@ -0,0 +1,177 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/opt/pre/skip_pushdown_pass.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
|
||||
#endif
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/samplers/skip_first_epoch_sampler_ir.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// Constructor: start with no pending skip count.
SkipPushdownPass::SkipNodes::SkipNodes()
    : skip_count_(0) {
}
|
||||
|
||||
// activate the optimization steps, and increase skip_count_ (if not the first skip node in the pipeline)
|
||||
Status SkipPushdownPass::SkipNodes::Visit(std::shared_ptr<SkipNode> node, bool *const modified) {
|
||||
if (node->FirstEpochOnly() == false) {
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
skip_count_ += node->Count();
|
||||
nodes_to_remove_.push_back(node);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Completion check for a skip node. For a first-epoch-only skip node the pending skip count must already have been
// consumed; any other skip node is handled by the generic DatasetNode VisitAfter.
Status SkipPushdownPass::SkipNodes::VisitAfter(std::shared_ptr<SkipNode> node, bool *const modified) {
  if (!node->FirstEpochOnly()) {
    return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
  }
  CHECK_FAIL_RETURN_UNEXPECTED(skip_count_ == 0, "The skip_count_ cannot be non-zero here.");
  return Status::OK();
}
|
||||
|
||||
// Moving a skip below a batch node scales the pending skip count by the batch size. With ENABLE_PYTHON, a batch node
// that has a batch-size function is handled by the generic DatasetNode visitor instead.
Status SkipPushdownPass::SkipNodes::Visit(std::shared_ptr<BatchNode> node, bool *const modified) {
#ifdef ENABLE_PYTHON
  if (node->BatchSizeFunc()) {
    return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
  }
#endif
  CHECK_FAIL_RETURN_UNEXPECTED(skip_count_ >= 0, "The skip size cannot be negative.");
  if (skip_count_ != 0) {
    // An active skip node sits above this batch node.
    skip_count_ *= node->BatchSize();
  }
  return Status::OK();
}
|
||||
|
||||
// A project node leaves the pending skip count untouched; only the non-negativity of the count is checked.
Status SkipPushdownPass::SkipNodes::Visit(std::shared_ptr<ProjectNode> node, bool *const modified) {
  CHECK_FAIL_RETURN_UNEXPECTED(skip_count_ >= 0, "The skip size cannot be negative.");
  // Whether or not a skip is pending, there is nothing to do for this node.
  return Status::OK();
}
|
||||
|
||||
// A rename node leaves the pending skip count untouched; only the non-negativity of the count is checked.
Status SkipPushdownPass::SkipNodes::Visit(std::shared_ptr<RenameNode> node, bool *const modified) {
  CHECK_FAIL_RETURN_UNEXPECTED(skip_count_ >= 0, "The skip size cannot be negative.");
  // Whether or not a skip is pending, there is nothing to do for this node.
  return Status::OK();
}
|
||||
|
||||
// For a mappable source the pending skip is applied through the sampler: a SkipFirstEpochSamplerObj carrying the
// skip count is installed on the node, with the node's existing sampler (if any) attached as its child. The pending
// count is then cleared. Returns an error if the count is negative or the existing sampler cannot be attached.
Status SkipPushdownPass::SkipNodes::Visit(std::shared_ptr<MappableSourceNode> node, bool *const modified) {
  CHECK_FAIL_RETURN_UNEXPECTED(skip_count_ >= 0, "The skip size cannot be negative.");
  if (skip_count_ == 0) return Status::OK();  // no active skip node above. normal flow

  // we have an active skip node above.
  auto new_sampler = std::make_shared<SkipFirstEpochSamplerObj>(skip_count_);
  MS_LOG(INFO) << "Adding SkipFirstEpochSampler(" << skip_count_ << ")";
  auto sampler = node->Sampler();
  if (sampler != nullptr) {
    // AddChildSampler reports failure through its Status; propagate it rather than silently losing the child.
    RETURN_IF_NOT_OK(new_sampler->AddChildSampler(sampler));
  }
  node->SetSampler(new_sampler);
  skip_count_ = 0;

  return Status::OK();
}
|
||||
|
||||
// A map node does not change the pending skip count; if a skip is being pushed below it, a warning is logged.
Status SkipPushdownPass::SkipNodes::Visit(std::shared_ptr<MapNode> node, bool *const modified) {
  CHECK_FAIL_RETURN_UNEXPECTED(skip_count_ >= 0, "The skip size cannot be negative.");
  if (skip_count_ != 0) {
    // An active skip node sits above this map node.
    MS_LOG(WARNING)
      << "Pushing down skip node below a map node will result in slightly different outputs for random transformations.";
  }
  return Status::OK();
}
|
||||
|
||||
// For a non-mappable source, the pending skip (if any) is recorded so that a skip node is inserted directly above
// this node, and the pending count is cleared.
Status SkipPushdownPass::SkipNodes::Visit(std::shared_ptr<NonMappableSourceNode> node, bool *const modified) {
  CHECK_FAIL_RETURN_UNEXPECTED(skip_count_ >= 0, "The skip size cannot be negative.");
  if (skip_count_ != 0) {
    // Request a skip node above this node.
    insert_skip_above_.emplace_back(node, skip_count_);
    skip_count_ = 0;
  }
  return Status::OK();
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// Since MindDataset requires its own SkipFirstEpochSampler (which is not implemented) we insert the skip node above it.
|
||||
Status SkipPushdownPass::SkipNodes::Visit(std::shared_ptr<MindDataNode> node, bool *const modified) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(skip_count_ >= 0, "The skip size cannot be negative.");
|
||||
if (skip_count_ == 0) return Status::OK(); // no active skip node above. normal flow
|
||||
|
||||
// insert a skip node above
|
||||
insert_skip_above_.emplace_back(node, skip_count_);
|
||||
skip_count_ = 0;
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
// This functions is used for Ops that are random, and the ones in which Visit is Not Implemented yet;
|
||||
Status SkipPushdownPass::SkipNodes::Visit(std::shared_ptr<DatasetNode> node, bool *const modified) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(skip_count_ >= 0, "The skip size cannot be negative.");
|
||||
if (skip_count_ == 0) return Status::OK(); // no active skip node above. normal flow
|
||||
|
||||
// insert a skip node above
|
||||
insert_skip_above_.emplace_back(node, skip_count_);
|
||||
skip_count_ = 0;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// constructor
|
||||
SkipPushdownPass::SkipPushdownPass() {}
|
||||
|
||||
// Walk the tree to push down the skip node inserted when Reset is called.
|
||||
Status SkipPushdownPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) {
|
||||
MS_LOG(INFO) << "Pre pass: skip node pushdown pass started.";
|
||||
// Assumption: The total skip counts in the first_epoch_only skip node is less than the size of the dataset. This
|
||||
// assumption is not validated here.
|
||||
// Create the skip node pass which can identify which nodes need to be removed and which ones added.
|
||||
std::unique_ptr<SkipPushdownPass::SkipNodes> skip_nodes = std::make_unique<SkipPushdownPass::SkipNodes>();
|
||||
if (root_ir->IsSizeDefined()) {
|
||||
RETURN_IF_NOT_OK(skip_nodes->Run(root_ir, modified));
|
||||
}
|
||||
|
||||
// Update modified flag if there were any nodes identified to be removed
|
||||
if (skip_nodes->nodes_to_remove().empty() == false || skip_nodes->insert_skip_above().empty() == false) {
|
||||
*modified = true;
|
||||
}
|
||||
|
||||
// Add skip node(s) to the tree (if any)
|
||||
for (auto iter : skip_nodes->insert_skip_above()) {
|
||||
MS_LOG(INFO) << "Inserting a Skip(" << iter.second << ") node above this node: " << iter.first->Name();
|
||||
auto new_skip_node = std::make_shared<SkipNode>(iter.second);
|
||||
new_skip_node->SetFirstEpochOnly(true);
|
||||
RETURN_IF_NOT_OK(iter.first->InsertAbove(new_skip_node));
|
||||
}
|
||||
|
||||
// Then, execute the removal of any nodes that were set up for removal
|
||||
for (auto node : skip_nodes->nodes_to_remove()) {
|
||||
RETURN_IF_NOT_OK(node->Drop());
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Pre pass: skip node pushdown pass is complete.";
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,156 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_SKIP_PUSHDOWN_PASS_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_SKIP_PUSHDOWN_PASS_H_
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class BatchNode;
|
||||
class DatasetNode;
|
||||
class DatasetOp;
|
||||
class MappableSourceNode;
|
||||
class MapNode;
|
||||
#ifndef ENABLE_ANDROID
|
||||
class MindDataNode;
|
||||
#endif
|
||||
class NonMappableSourceNode;
|
||||
class ProjectNode;
|
||||
class RenameNode;
|
||||
class SkipNode;
|
||||
|
||||
/// \class SkipPushdownPass skip_pushdown_pass.h
|
||||
/// \brief This is a tree pass that will push down a skip node. It uses SkipNodes to first identify if we have a skip
///     node, and then based on the node types we observe in the tree, decide where to place the skip node (or use a
///     SkipFirstEpochSampler for MappableSource nodes).
|
||||
class SkipPushdownPass : public IRTreePass {
|
||||
/// \class SkipNodes
|
||||
/// \brief This is a NodePass whose job is to handle different nodes accordingly.
|
||||
/// It works in conjunction with the SkipPushdownPass.
|
||||
class SkipNodes : public IRNodePass {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
SkipNodes();
|
||||
|
||||
/// \brief Destructor
|
||||
~SkipNodes() = default;
|
||||
|
||||
/// \brief Perform skip node pushdown initiation check on a SkipNode
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[in, out] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status Visit(std::shared_ptr<SkipNode> node, bool *const modified) override;
|
||||
|
||||
/// \brief Perform skip node pushdown completion check on a SkipNode
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[in, out] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status VisitAfter(std::shared_ptr<SkipNode> node, bool *const modified) override;
|
||||
|
||||
/// \brief Perform skip node pushdown check on a BatchNode
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[in, out] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status Visit(std::shared_ptr<BatchNode> node, bool *const modified) override;
|
||||
|
||||
/// \brief Perform skip node pushdown check on a ProjectNode
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[in, out] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status Visit(std::shared_ptr<ProjectNode> node, bool *const modified) override;
|
||||
|
||||
/// \brief Perform skip node pushdown check on a RenameNode
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[in, out] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status Visit(std::shared_ptr<RenameNode> node, bool *const modified) override;
|
||||
|
||||
/// \brief Perform skip node pushdown check on a MappableSourceNode
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[in, out] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status Visit(std::shared_ptr<MappableSourceNode> node, bool *const modified) override;
|
||||
|
||||
/// \brief Perform skip node pushdown check on a MapNode
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[in, out] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status Visit(std::shared_ptr<MapNode> node, bool *const modified) override;
|
||||
|
||||
/// \brief Perform skip node pushdown check on a NonMappableSourceNode
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[in, out] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status Visit(std::shared_ptr<NonMappableSourceNode> node, bool *const modified) override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Perform skip node pushdown check on a MindDataNode
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[in, out] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status Visit(std::shared_ptr<MindDataNode> node, bool *const modified) override;
|
||||
#endif
|
||||
|
||||
/// \brief Perform skip node pushdown check on a DatasetNode
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[in, out] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status Visit(std::shared_ptr<DatasetNode> node, bool *const modified) override;
|
||||
|
||||
/// \brief Perform skip node pushdown completion check on a DatasetNode
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[in, out] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *const modified) override { return Status::OK(); };
|
||||
|
||||
/// \brief Getter
|
||||
/// \return All the nodes where a skip node needs to be inserted above (and the skip count).
|
||||
const std::vector<std::pair<std::shared_ptr<DatasetNode>, int64_t>> &insert_skip_above() const {
|
||||
return insert_skip_above_;
|
||||
}
|
||||
|
||||
/// \brief Getter
|
||||
/// \return All the nodes to be removed
|
||||
const std::vector<std::shared_ptr<DatasetNode>> &nodes_to_remove() const { return nodes_to_remove_; }
|
||||
|
||||
private:
|
||||
std::vector<std::pair<std::shared_ptr<DatasetNode>, int64_t>> insert_skip_above_;
|
||||
std::vector<std::shared_ptr<DatasetNode>> nodes_to_remove_;
|
||||
int64_t skip_count_;
|
||||
};
|
||||
|
||||
public:
|
||||
/// \brief Constructor
|
||||
SkipPushdownPass();
|
||||
|
||||
/// \brief Destructor
|
||||
~SkipPushdownPass() = default;
|
||||
|
||||
/// \brief Runs a skip_pushdown pass to push down the skip node found in the tree (for Reset scenario).
|
||||
/// \param[in, out] tree The tree to operate on.
|
||||
/// \param[in, out] modified Indicator if the tree was modified.
|
||||
/// \return Status The status code returned
|
||||
Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) override;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_SKIP_PUSHDOWN_PASS_H_
|
|
@ -36,6 +36,7 @@
|
|||
#include "minddata/dataset/engine/opt/pre/getter_pass.h"
|
||||
#include "minddata/dataset/engine/opt/pre/input_validation_pass.h"
|
||||
#include "minddata/dataset/engine/opt/pre/node_removal_pass.h"
|
||||
#include "minddata/dataset/engine/opt/pre/skip_pushdown_pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -57,8 +58,11 @@ Status TreeAdapter::PrePass(std::shared_ptr<DatasetNode> ir) {
|
|||
MS_LOG(INFO) << "Running pre pass loops.";
|
||||
actions.emplace_back(std::make_unique<InputValidationPass>());
|
||||
actions.emplace_back(std::make_unique<CacheValidationPass>());
|
||||
if (usage_ == kDeReset) {
|
||||
actions.emplace_back(std::make_unique<AddSkipPass>());
|
||||
actions.emplace_back(std::make_unique<SkipPushdownPass>());
|
||||
}
|
||||
actions.emplace_back(std::make_unique<NodeRemovalPass>());
|
||||
if (usage_ == kDeReset) actions.emplace_back(std::make_unique<AddSkipPass>());
|
||||
actions.emplace_back(std::make_unique<EpochCtrlPass>());
|
||||
if (usage_ == kDeGetter) actions.emplace_back(std::make_unique<GetterPass>());
|
||||
#ifndef ENABLE_ANDROID
|
||||
|
|
|
@ -153,6 +153,7 @@ if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full")
|
|||
${MINDDATA_DIR}/engine/ir/datasetops/source/samplers/random_sampler_ir.cc
|
||||
${MINDDATA_DIR}/engine/ir/datasetops/source/samplers/samplers_ir.cc
|
||||
${MINDDATA_DIR}/engine/ir/datasetops/source/samplers/sequential_sampler_ir.cc
|
||||
${MINDDATA_DIR}/engine/ir/datasetops/source/samplers/skip_first_epoch_sampler_ir.cc
|
||||
${MINDDATA_DIR}/engine/ir/datasetops/source/samplers/subset_random_sampler_ir.cc
|
||||
${MINDDATA_DIR}/engine/ir/datasetops/source/samplers/subset_sampler_ir.cc
|
||||
${MINDDATA_DIR}/engine/ir/datasetops/source/samplers/weighted_random_sampler_ir.cc
|
||||
|
@ -179,6 +180,7 @@ if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full")
|
|||
${MINDDATA_DIR}/engine/opt/pre/node_removal_pass.cc
|
||||
${MINDDATA_DIR}/engine/opt/pre/epoch_ctrl_pass.cc
|
||||
${MINDDATA_DIR}/engine/opt/pre/deep_copy_pass.cc
|
||||
${MINDDATA_DIR}/engine/opt/pre/skip_pushdown_pass.cc
|
||||
${MINDDATA_DIR}/engine/opt/post/auto_worker_pass.cc
|
||||
${MINDDATA_DIR}/engine/opt/pass.cc
|
||||
${MINDDATA_DIR}/engine/perf/auto_tune.cc
|
||||
|
@ -193,6 +195,7 @@ if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full")
|
|||
${MINDDATA_DIR}/engine/datasetops/source/sampler/pk_sampler.cc
|
||||
${MINDDATA_DIR}/engine/datasetops/source/sampler/random_sampler.cc
|
||||
${MINDDATA_DIR}/engine/datasetops/source/sampler/sequential_sampler.cc
|
||||
${MINDDATA_DIR}/engine/datasetops/source/sampler/skip_first_epoch_sampler.cc
|
||||
${MINDDATA_DIR}/engine/datasetops/source/sampler/subset_random_sampler.cc
|
||||
${MINDDATA_DIR}/engine/datasetops/source/sampler/weighted_random_sampler.cc
|
||||
${MINDDATA_DIR}/engine/runtime_context.cc
|
||||
|
|
|
@ -80,7 +80,6 @@ except ModuleNotFoundError:
|
|||
if platform.system().lower() == "darwin" and multiprocessing.get_start_method() != "fork":
|
||||
multiprocessing.set_start_method("fork", True)
|
||||
|
||||
|
||||
OffloadToManualOffloadMode = {
|
||||
None: cde.ManualOffloadMode.UNSPECIFIED,
|
||||
False: cde.ManualOffloadMode.DISABLED,
|
||||
|
@ -120,7 +119,7 @@ def _reset_training_dataset(step):
|
|||
"""
|
||||
dataset = _get_training_dataset()
|
||||
if dataset is not None:
|
||||
dataset.reset(step)
|
||||
dataset._reset(step) # pylint: disable=W0212
|
||||
else:
|
||||
raise RuntimeError("Training dataset is not set.")
|
||||
|
||||
|
@ -3528,7 +3527,7 @@ class _ToDevice:
|
|||
def send(self):
|
||||
self._to_device.Send()
|
||||
|
||||
def reset(self, step):
|
||||
def _reset(self, step):
|
||||
self._to_device.Reset(step)
|
||||
|
||||
def stop_send(self):
|
||||
|
@ -3638,10 +3637,10 @@ class TransferDataset(Dataset):
|
|||
if self._to_device is not None:
|
||||
self._to_device.continue_send()
|
||||
|
||||
def reset(self, step):
|
||||
def _reset(self, step):
|
||||
if self._to_device is not None:
|
||||
logger.info("Reset the dataset pipeline to step " + str(step))
|
||||
self._to_device.reset(step)
|
||||
self._to_device._reset(step) # pylint: disable=W0212
|
||||
|
||||
def get_data_info(self):
|
||||
"""
|
||||
|
|
|
@ -178,7 +178,7 @@ class Iterator:
|
|||
self._getters()
|
||||
return self._col_names
|
||||
|
||||
def reset(self, step):
|
||||
def _reset(self, step):
|
||||
"""
|
||||
Reset the iterator to the given step number.
|
||||
|
||||
|
|
|
@ -330,9 +330,9 @@ class DatasetHelper:
|
|||
"""Continue to send data to device at the beginning of epoch."""
|
||||
self.iter.continue_send()
|
||||
|
||||
def reset(self, step):
|
||||
def _reset(self, step):
|
||||
"""Reset the dataset to the provided step."""
|
||||
self.iter.reset(step)
|
||||
self.iter._reset(step) # pylint: disable=W0212
|
||||
|
||||
def get_data_info(self):
|
||||
"""
|
||||
|
@ -387,10 +387,11 @@ class _DatasetIter:
|
|||
self.stop_send = dataset.__transfer_dataset__.stop_send
|
||||
self.release = dataset.__transfer_dataset__.release
|
||||
self.continue_send = dataset.__transfer_dataset__.continue_send
|
||||
self.reset = dataset.__transfer_dataset__.reset
|
||||
self.get_data_info = dataset.__transfer_dataset__.get_data_info
|
||||
self.dynamic_min_max_shapes = dataset.dynamic_min_max_shapes
|
||||
self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)
|
||||
if hasattr(dataset.__transfer_dataset__, "_reset"):
|
||||
self._reset = dataset.__transfer_dataset__._reset # pylint: disable=W0212
|
||||
|
||||
def __iter__(self):
|
||||
self.index = 0
|
||||
|
|
|
@ -162,6 +162,7 @@ SET(DE_UT_SRCS
|
|||
rgba_to_bgr_op_test.cc
|
||||
rgba_to_rgb_op_test.cc
|
||||
schema_test.cc
|
||||
skip_pushdown_optimization_pass_test.cc
|
||||
slice_op_test.cc
|
||||
sliding_window_op_test.cc
|
||||
solarize_op_test.cc
|
||||
|
|
|
@ -0,0 +1,676 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "common/common.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||
#include "minddata/dataset/engine/opt/pre/skip_pushdown_pass.h"
|
||||
#include "minddata/dataset/include/dataset/samplers.h"
|
||||
#include "minddata/dataset/include/dataset/vision.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
|
||||
class MindDataSkipPushdownTestOptimizationPass : public UT::DatasetOpTesting {
|
||||
protected:
|
||||
MindDataSkipPushdownTestOptimizationPass() {}
|
||||
|
||||
/// \brief Compile and compare two datasets
|
||||
/// \param[in] root_original Original dataset to be added the skip step
|
||||
/// \param[in] root_target Target dataset for compare
|
||||
/// \param[in] step Skip step
|
||||
/// \return Status of the function
|
||||
Status prepare_trees(std::shared_ptr<Dataset> root_original, std::shared_ptr<Dataset> root_target, int64_t step = 0) {
|
||||
auto ir_tree = std::make_shared<TreeAdapter>(TreeAdapter::UsageFlag::kDeReset);
|
||||
// Compile adds a new RootNode to the top of the tree
|
||||
RETURN_IF_NOT_OK(ir_tree->Compile(root_original->IRNode(), 1, step));
|
||||
|
||||
auto ir_tree_target = std::make_shared<TreeAdapter>();
|
||||
// Compile adds a new RootNode to the top of the tree
|
||||
RETURN_IF_NOT_OK(ir_tree_target->Compile(root_target->IRNode(), 1,
|
||||
0)); // Step is 0 for target node tree
|
||||
|
||||
if (step != 0) {
|
||||
RETURN_IF_NOT_OK(compare_pass(ir_tree_target->RootIRNode(), ir_tree->RootIRNode()));
|
||||
}
|
||||
RETURN_IF_NOT_OK(compare_pass_row(ir_tree_target, ir_tree));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief Compare two dataset node trees
|
||||
/// \param[in] expect Expected node tree for compare
|
||||
/// \param[in] root Root node tree for compare
|
||||
/// \return Status of the function
|
||||
Status compare_pass(std::shared_ptr<DatasetNode> expect, std::shared_ptr<DatasetNode> root) {
|
||||
if (expect->Children().size() == root->Children().size() && expect->Children().size() == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
if (expect->Children().size() == root->Children().size() && expect->Children().size() != 0) {
|
||||
for (int i = 0; i < expect->Children().size(); i++) {
|
||||
std::string expect_name = expect->Children()[i]->Name();
|
||||
std::string root_name = root->Children()[i]->Name();
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(expect_name == root_name,
|
||||
"Expect child is " + expect_name + ", but got " + root_name);
|
||||
RETURN_IF_NOT_OK(compare_pass(expect->Children()[i], root->Children()[i]));
|
||||
}
|
||||
} else {
|
||||
return Status(StatusCode::kMDUnexpectedError, "Skip Optimization is not working as expected, expect to have " +
|
||||
std::to_string(expect->Children().size()) +
|
||||
" operation, but got " + std::to_string(root->Children().size()));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief Compare each row of two dataset node trees
|
||||
/// \param[in] expect Expected tree for compare
|
||||
/// \param[in] root Root tree for compare
|
||||
/// \return Status of the function
|
||||
Status compare_pass_row(std::shared_ptr<TreeAdapter> expect, std::shared_ptr<TreeAdapter> root) {
|
||||
TensorRow row_expect;
|
||||
TensorRow row_root;
|
||||
RETURN_IF_NOT_OK(expect->GetNext(&row_expect));
|
||||
RETURN_IF_NOT_OK(root->GetNext(&row_root));
|
||||
while (row_expect.size() != 0 && row_root.size() != 0) {
|
||||
std::vector<std::shared_ptr<Tensor>> e = row_expect.getRow();
|
||||
std::vector<std::shared_ptr<Tensor>> r = row_root.getRow();
|
||||
for (int i = 0; i < e.size(); i++) {
|
||||
nlohmann::json out_json;
|
||||
RETURN_IF_NOT_OK(e[i]->to_json(&out_json));
|
||||
std::stringstream json_ss;
|
||||
json_ss << out_json;
|
||||
|
||||
nlohmann::json out_json1;
|
||||
RETURN_IF_NOT_OK(r[i]->to_json(&out_json1));
|
||||
std::stringstream json_ss1;
|
||||
json_ss1 << out_json1;
|
||||
EXPECT_EQ(json_ss.str(), json_ss1.str());
|
||||
}
|
||||
RETURN_IF_NOT_OK(expect->GetNext(&row_expect));
|
||||
RETURN_IF_NOT_OK(root->GetNext(&row_root));
|
||||
}
|
||||
EXPECT_EQ(row_expect.size(), row_root.size());
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
TensorRow VecToRow(const MSTensorVec &v);
|
||||
|
||||
MSTensorVec RowToVec(const TensorRow &v);
|
||||
|
||||
MSTensorVec Predicate1(MSTensorVec in);
|
||||
|
||||
/// Feature: MindData Skip Pushdown Optimization Pass Test
|
||||
/// Description: Test MindData Skip Pushdown Optimization Pass with Sampler in MappableSourceNode
|
||||
/// Expectation: Skip node is pushed down and removed after optimization pass
|
||||
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownMappableSourceNode) {
|
||||
MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownMappableSourceNode.";
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
|
||||
auto root = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>());
|
||||
auto root_target = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>(2));
|
||||
EXPECT_OK(prepare_trees(root, root_target, 2));
|
||||
}
|
||||
|
||||
/// Feature: MindData Skip Pushdown Optimization Pass Test
|
||||
/// Description: Test MindData Skip Pushdown Optimization Pass with Batch Operation
|
||||
/// Expectation: Skip node is pushed down and removed after optimization pass
|
||||
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownBatch) {
|
||||
MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownBatch.";
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
|
||||
std::shared_ptr<Dataset> root;
|
||||
std::shared_ptr<Dataset> root_target;
|
||||
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Batch(5)->Skip(5);
|
||||
root_target = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>(25))->Batch(5);
|
||||
EXPECT_OK(prepare_trees(root, root_target, 0));
|
||||
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Batch(5);
|
||||
root_target = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>(25))->Batch(5);
|
||||
EXPECT_OK(prepare_trees(root, root_target, 5));
|
||||
}
|
||||
|
||||
/// Feature: MindData Skip Pushdown Optimization Pass Test
|
||||
/// Description: Test MindData Skip Pushdown Optimization Pass with Rename Operation
|
||||
/// Expectation: Skip node is pushed down and removed after optimization pass
|
||||
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownRename) {
|
||||
MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownRename.";
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
|
||||
std::shared_ptr<Dataset> root;
|
||||
std::shared_ptr<Dataset> root_target;
|
||||
|
||||
root =
|
||||
ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Rename({"label"}, {"fake_label"})->Skip(5);
|
||||
root_target =
|
||||
ImageFolder(folder_path, false, std::make_shared<SequentialSampler>(5))->Rename({"label"}, {"fake_label"});
|
||||
EXPECT_OK(prepare_trees(root, root_target, 0));
|
||||
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Rename({"label"}, {"fake_label"});
|
||||
root_target =
|
||||
ImageFolder(folder_path, false, std::make_shared<SequentialSampler>(5))->Rename({"label"}, {"fake_label"});
|
||||
EXPECT_OK(prepare_trees(root, root_target, 5));
|
||||
}
|
||||
|
||||
/// Feature: MindData Skip Pushdown Optimization Pass Test
|
||||
/// Description: Test MindData Skip Pushdown Optimization Pass with Project Operation
|
||||
/// Expectation: Skip node is pushed down and removed for Project after optimization pass
|
||||
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownProject) {
|
||||
MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownProject.";
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
|
||||
std::shared_ptr<Dataset> root;
|
||||
std::shared_ptr<Dataset> root_target;
|
||||
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Project({"label", "image"})->Skip(10);
|
||||
root_target = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>(10))->Project({"label", "image"});
|
||||
EXPECT_OK(prepare_trees(root, root_target, 0));
|
||||
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Project({"label", "image"});
|
||||
root_target = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>(10))->Project({"label", "image"});
|
||||
EXPECT_OK(prepare_trees(root, root_target, 10));
|
||||
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())
|
||||
->Skip(1)
|
||||
->Project({"label", "image"})
|
||||
->Skip(10);
|
||||
root_target = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>(11))->Project({"label", "image"});
|
||||
EXPECT_OK(prepare_trees(root, root_target, 0));
|
||||
}
|
||||
|
||||
/// Feature: MindData Skip Pushdown Optimization Pass Test
|
||||
/// Description: Test MindData Skip Pushdown Optimization Pass with Concat Operation
|
||||
/// Expectation: Skip node cannot be pushed down for Concat after optimization pass
|
||||
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownConcat) {
|
||||
MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownConcat.";
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
|
||||
std::shared_ptr<Dataset> root;
|
||||
std::shared_ptr<Dataset> root_target;
|
||||
|
||||
std::vector<std::shared_ptr<Dataset>> datasets = {
|
||||
ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())};
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Concat(datasets)->Skip(10);
|
||||
root_target = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Concat(datasets)->Skip(10);
|
||||
EXPECT_OK(prepare_trees(root, root_target, 0));
|
||||
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Concat(datasets);
|
||||
root_target = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Concat(datasets)->Skip(10);
|
||||
EXPECT_OK(prepare_trees(root, root_target, 10));
|
||||
}
|
||||
|
||||
/// Feature: MindData Skip Pushdown Optimization Pass Test
|
||||
/// Description: Test MindData Skip Pushdown Optimization Pass with Zip Operation
|
||||
/// Expectation: Skip node cannot be pushed down for Zip after optimization pass
|
||||
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownZip) {
|
||||
MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownZip.";
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
|
||||
std::shared_ptr<Dataset> root;
|
||||
std::shared_ptr<Dataset> root_target;
|
||||
|
||||
std::vector<std::shared_ptr<Dataset>> datasets = {
|
||||
ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Project({"label"})};
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Project({"image"})->Zip(datasets);
|
||||
root_target =
|
||||
ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Project({"image"})->Zip(datasets)->Skip(10);
|
||||
EXPECT_OK(prepare_trees(root, root_target, 10));
|
||||
}
|
||||
|
||||
/// Feature: MindData Skip Pushdown Optimization Pass Test
|
||||
/// Description: Test MindData Skip Pushdown Optimization Pass with Repeat Operation
|
||||
/// Expectation: Skip operation cannot be pushed down for Repeat after optimization pass
|
||||
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownRepeat) {
|
||||
MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownRepeat.";
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
|
||||
std::shared_ptr<Dataset> root;
|
||||
std::shared_ptr<Dataset> root_target;
|
||||
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Repeat(5);
|
||||
root_target = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Repeat(5)->Skip(11);
|
||||
EXPECT_OK(prepare_trees(root, root_target, 11));
|
||||
}
|
||||
|
||||
/// Feature: MindData Skip Pushdown Optimization Pass Test
|
||||
/// Description: Test MindData Skip Pushdown Optimization Pass with Take Operation
|
||||
/// Expectation: Skip operation cannot be pushed down for Take after optimization pass
|
||||
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownTake) {
|
||||
MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownTake.";
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
|
||||
std::shared_ptr<Dataset> root;
|
||||
std::shared_ptr<Dataset> root_target;
|
||||
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Take(20);
|
||||
root_target = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Take(20)->Skip(10);
|
||||
EXPECT_OK(prepare_trees(root, root_target, 10));
|
||||
}
|
||||
|
||||
/// Feature: MindData Skip Pushdown Optimization Pass Test
|
||||
/// Description: Test MindData Skip Pushdown Optimization Pass with Skip Operation
|
||||
/// Expectation: Skip node cannot be pushed down after optimization pass
|
||||
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownSkip) {
|
||||
MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownSkip.";
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
|
||||
std::shared_ptr<Dataset> root;
|
||||
std::shared_ptr<Dataset> root_target;
|
||||
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Skip(2)->Skip(3);
|
||||
root_target = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Skip(2)->Skip(3)->Skip(5);
|
||||
EXPECT_OK(prepare_trees(root, root_target, 5));
|
||||
}
|
||||
|
||||
/// Feature: MindData Skip Pushdown Optimization Pass Test
|
||||
/// Description: Test MindData Skip Pushdown Optimization Pass with NonMappableSourceNode(CSV)
|
||||
/// Expectation: Skip node is pushed down after optimization pass, but cannot be removed
|
||||
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownNonMappableSourceNode) {
|
||||
MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownNonMappableSourceNode.";
|
||||
std::string folder_path = datasets_root_path_ + "/testCSV/append.csv";
|
||||
|
||||
std::shared_ptr<Dataset> root;
|
||||
std::shared_ptr<Dataset> root_target;
|
||||
|
||||
std::vector<std::string> column_names = {"col1", "col2", "col3", "col4"};
|
||||
root = CSV({folder_path}, ',', {}, column_names, 0, ShuffleMode::kFalse)->Rename({"col1"}, {"fake_label"});
|
||||
root_target =
|
||||
CSV({folder_path}, ',', {}, column_names, 0, ShuffleMode::kFalse)->Skip(1)->Rename({"col1"}, {"fake_label"});
|
||||
EXPECT_OK(prepare_trees(root, root_target, 1));
|
||||
|
||||
root = CSV({folder_path}, ',', {}, column_names, 0, ShuffleMode::kFalse)->Rename({"col1"}, {"fake_label"})->Skip(1);
|
||||
root_target =
|
||||
CSV({folder_path}, ',', {}, column_names, 0, ShuffleMode::kFalse)->Skip(1)->Rename({"col1"}, {"fake_label"});
|
||||
EXPECT_OK(prepare_trees(root, root_target, 0));
|
||||
}
|
||||
|
||||
/// Feature: MindData Skip Pushdown Optimization Pass Test
|
||||
/// Description: Test MindData Skip Pushdown Optimization Pass with Batch and Rename Operations
|
||||
/// Expectation: Skip node is pushed down after optimization pass
|
||||
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownCombineOperations1) {
|
||||
MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownCombineOperations1.";
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
|
||||
std::shared_ptr<Dataset> root;
|
||||
std::shared_ptr<Dataset> root_target;
|
||||
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())
|
||||
->Batch(5)
|
||||
->Skip(2)
|
||||
->Rename({"label"}, {"fake_label"});
|
||||
root_target = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())
|
||||
->Batch(5)
|
||||
->Skip(2)
|
||||
->Skip(2)
|
||||
->Rename({"label"}, {"fake_label"});
|
||||
EXPECT_OK(prepare_trees(root, root_target, 2));
|
||||
}
|
||||
|
||||
/// Feature: MindData Skip Pushdown Optimization Pass Test
|
||||
/// Description: Test MindData Skip Pushdown Optimization Pass with Batch and Skip Operations and Sampler
|
||||
/// Expectation: Skip node is pushed down after optimization pass
|
||||
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownCombineOperations2) {
|
||||
MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownCombineOperations2.";
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
|
||||
std::shared_ptr<Dataset> root;
|
||||
std::shared_ptr<Dataset> root_target;
|
||||
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Skip(2)->Batch(5);
|
||||
root_target = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Skip(2)->Skip(10)->Batch(5);
|
||||
EXPECT_OK(prepare_trees(root, root_target, 2));
|
||||
}
|
||||
|
||||
/// Feature: MindData Skip Pushdown Optimization Pass Test
|
||||
/// Description: Test MindData Skip Pushdown Optimization Pass with Take and Project Operations
|
||||
/// Expectation: Skip node is pushed down for Project but not for Take
|
||||
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownCombineOperations3) {
|
||||
MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownCombineOperations3.";
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
|
||||
std::shared_ptr<Dataset> root;
|
||||
std::shared_ptr<Dataset> root_target;
|
||||
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Take(20)->Project({"label", "image"});
|
||||
root_target = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())
|
||||
->Take(20)
|
||||
->Skip(2)
|
||||
->Project({"label", "image"});
|
||||
EXPECT_OK(prepare_trees(root, root_target, 2));
|
||||
}
|
||||
|
||||
/// Feature: MindData Skip Pushdown Optimization Pass Test
|
||||
/// Description: Test MindData Skip Pushdown Optimization Pass with no Skip/ Skip(0) Operation
|
||||
/// Expectation: Skip(0) shows the same result as no Skip operation
|
||||
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownSkip0) {
|
||||
MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownSkip0.";
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
|
||||
std::shared_ptr<Dataset> root;
|
||||
std::shared_ptr<Dataset> root_target;
|
||||
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Project({"label", "image"})->Take(5);
|
||||
root_target =
|
||||
ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Project({"label", "image"})->Take(5);
|
||||
EXPECT_OK(prepare_trees(root, root_target, 0));
|
||||
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Project({"label", "image"})->Skip(0);
|
||||
root_target = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Project({"label", "image"});
|
||||
EXPECT_OK(prepare_trees(root, root_target, 0));
|
||||
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())
|
||||
->Skip(0)
|
||||
->Project({"label", "image"})
|
||||
->Skip(0);
|
||||
root_target = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Project({"label", "image"});
|
||||
EXPECT_OK(prepare_trees(root, root_target, 0));
|
||||
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Skip(0)->Project({"label", "image"});
|
||||
root_target =
|
||||
ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Skip(1)->Project({"label", "image"});
|
||||
EXPECT_OK(prepare_trees(root, root_target, 1));
|
||||
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())
|
||||
->Skip(2)
|
||||
->Skip(1)
|
||||
->Project({"label", "image"});
|
||||
root_target = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())
|
||||
->Skip(2)
|
||||
->Skip(1)
|
||||
->Skip(1)
|
||||
->Project({"label", "image"});
|
||||
EXPECT_OK(prepare_trees(root, root_target, 1));
|
||||
}
|
||||
|
||||
/// Feature: MindData Skip Pushdown Optimization Pass Test
|
||||
/// Description: Test MindData Skip Pushdown Optimization Pass with Repeat Operation
|
||||
/// Expectation: Skip operation cannot be pushed down for Repeat operation
|
||||
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownRepeat2) {
|
||||
MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownRepeat2.";
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
|
||||
std::shared_ptr<Dataset> root;
|
||||
std::shared_ptr<Dataset> root_target;
|
||||
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Repeat(3)->Skip(1);
|
||||
root_target = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Repeat(3)->Skip(1);
|
||||
EXPECT_OK(prepare_trees(root, root_target, 0));
|
||||
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Repeat(3);
|
||||
root_target = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Repeat(3)->Skip(50);
|
||||
EXPECT_OK(prepare_trees(root, root_target, 50));
|
||||
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Skip(1)->Repeat(3);
|
||||
root_target = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Skip(1)->Repeat(3)->Skip(50);
|
||||
EXPECT_OK(prepare_trees(root, root_target, 50));
|
||||
}
|
||||
|
||||
/// Feature: MindData Skip Pushdown Optimization Pass Test
|
||||
/// Description: Test MindData Skip Pushdown Optimization Pass with Take Operation
|
||||
/// Expectation: Skip operation cannot be pushed down for Take operation
|
||||
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownTake2) {
|
||||
MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownTake2.";
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
|
||||
std::shared_ptr<Dataset> root;
|
||||
std::shared_ptr<Dataset> root_target;
|
||||
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Skip(1)->Take(3);
|
||||
root_target = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Skip(1)->Take(3);
|
||||
EXPECT_OK(prepare_trees(root, root_target, 0));
|
||||
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Take(3);
|
||||
root_target = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Take(3)->Skip(1);
|
||||
EXPECT_OK(prepare_trees(root, root_target, 1));
|
||||
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Skip(1)->Take(5);
|
||||
root_target = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Skip(1)->Take(5)->Skip(1);
|
||||
EXPECT_OK(prepare_trees(root, root_target, 1));
|
||||
}
|
||||
|
||||
/// Feature: MindData Skip Pushdown Optimization Pass Test
|
||||
/// Description: Test MindData Skip Pushdown Optimization Pass with Concat/Zip Operation
|
||||
/// Expectation: Skip node cannot be removed for Concat/Zip after optimization pass
|
||||
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownUnsupported) {
|
||||
MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownUnsupported.";
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
|
||||
std::shared_ptr<Dataset> root;
|
||||
std::shared_ptr<Dataset> root_target;
|
||||
|
||||
std::vector<std::shared_ptr<Dataset>> datasets = {
|
||||
ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Project({"label"})};
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Project({"label"})->Concat(datasets);
|
||||
root_target = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())
|
||||
->Project({"label"})
|
||||
->Concat(datasets)
|
||||
->Skip(2);
|
||||
EXPECT_OK(prepare_trees(root, root_target, 2));
|
||||
|
||||
root =
|
||||
ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Project({"image"})->Zip(datasets)->Skip(1);
|
||||
root_target =
|
||||
ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Project({"image"})->Zip(datasets)->Skip(1);
|
||||
EXPECT_OK(prepare_trees(root, root_target, 0));
|
||||
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Project({"image"})->Zip(datasets);
|
||||
root_target =
|
||||
ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Project({"image"})->Zip(datasets)->Skip(1);
|
||||
EXPECT_OK(prepare_trees(root, root_target, 1));
|
||||
}
|
||||
|
||||
/// Feature: MindData Skip Pushdown Optimization Pass Test
|
||||
/// Description: Test MindData Skip Pushdown Optimization Pass with Filter Operation
|
||||
/// Expectation: Skip node cannot be pushed down for Filter after optimization pass
|
||||
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownUnsupported_Filter) {
|
||||
MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownUnsupported_Filter.";
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
|
||||
std::shared_ptr<Dataset> root;
|
||||
std::shared_ptr<Dataset> root_target;
|
||||
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Skip(1)->Filter(Predicate1, {"label"});
|
||||
root_target =
|
||||
ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Skip(1)->Filter(Predicate1, {"label"});
|
||||
EXPECT_OK(prepare_trees(root, root_target, 0));
|
||||
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Filter(Predicate1, {"label"})->Skip(1);
|
||||
root_target =
|
||||
ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Filter(Predicate1, {"label"})->Skip(1);
|
||||
EXPECT_OK(prepare_trees(root, root_target, 0));
|
||||
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Skip(1)->Filter(Predicate1, {"label"});
|
||||
root_target = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())
|
||||
->Skip(1)
|
||||
->Filter(Predicate1, {"label"})
|
||||
->Skip(1);
|
||||
EXPECT_OK(prepare_trees(root, root_target, 1));
|
||||
}
|
||||
|
||||
/// Feature: MindData Skip Pushdown Optimization Pass Test
|
||||
/// Description: Test MindData Skip Pushdown Optimization Pass with SubsetSampler as child
|
||||
/// Expectation: Skip node is removed for Rename/Project/Map after optimization pass
|
||||
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownSubsetSampler) {
|
||||
MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownSubsetSampler.";
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
|
||||
std::shared_ptr<Dataset> root;
|
||||
std::shared_ptr<Dataset> root_target;
|
||||
std::vector<int64_t> indices = {0, 1, 2, 3, 4, 5};
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SubsetRandomSampler>(indices, 3))->Skip(1);
|
||||
auto sampler = std::make_shared<SequentialSampler>(1);
|
||||
sampler->AddChild(std::make_shared<SubsetRandomSampler>(indices, 3));
|
||||
root_target = ImageFolder(folder_path, false, sampler);
|
||||
EXPECT_OK(prepare_trees(root, root_target, 0));
|
||||
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SubsetRandomSampler>(indices, 4))
|
||||
->Rename({"label"}, {"fake_label"});
|
||||
sampler = std::make_shared<SequentialSampler>(1);
|
||||
sampler->AddChild(std::make_shared<SubsetRandomSampler>(indices, 4));
|
||||
root_target = ImageFolder(folder_path, false, sampler)->Rename({"label"}, {"fake_label"});
|
||||
EXPECT_OK(prepare_trees(root, root_target, 1));
|
||||
|
||||
root =
|
||||
ImageFolder(folder_path, false, std::make_shared<SubsetRandomSampler>(indices, 10))->Project({"label", "image"});
|
||||
sampler = std::make_shared<SequentialSampler>(1);
|
||||
sampler->AddChild(std::make_shared<SubsetRandomSampler>(indices, 10));
|
||||
root_target = ImageFolder(folder_path, false, sampler)->Project({"label", "image"});
|
||||
EXPECT_OK(prepare_trees(root, root_target, 1));
|
||||
|
||||
std::vector<std::shared_ptr<TensorTransform>> transforms;
|
||||
std::vector<int32_t> size = {80, 80};
|
||||
std::vector<uint32_t> ignore = {20, 20, 20, 20};
|
||||
std::shared_ptr<TensorTransform> operation1 = std::make_shared<vision::AutoContrast>(0.5, ignore);
|
||||
std::shared_ptr<TensorTransform> operation2 = std::make_shared<vision::CenterCrop>(size);
|
||||
transforms.push_back(operation1);
|
||||
transforms.push_back(operation2);
|
||||
root = ImageFolder(folder_path, true, std::make_shared<SubsetRandomSampler>(indices, 3))->Map(transforms);
|
||||
sampler = std::make_shared<SequentialSampler>(1);
|
||||
sampler->AddChild(std::make_shared<SubsetRandomSampler>(indices, 3));
|
||||
root_target = ImageFolder(folder_path, true, sampler)->Map(transforms);
|
||||
EXPECT_OK(prepare_trees(root, root_target, 1));
|
||||
}
|
||||
|
||||
/// Feature: MindData Skip Pushdown Optimization Pass Test
|
||||
/// Description: Test MindData Skip Pushdown Optimization Pass with Batch Operation
|
||||
/// Expectation: Skip node is pushed down for Batch after optimization pass
|
||||
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownBatch2) {
|
||||
MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownBatch2.";
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
|
||||
std::shared_ptr<Dataset> root;
|
||||
std::shared_ptr<Dataset> root_target;
|
||||
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Skip(1)->Batch(5);
|
||||
root_target = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Skip(1)->Batch(5);
|
||||
EXPECT_OK(prepare_trees(root, root_target, 0));
|
||||
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Skip(3)->Batch(5);
|
||||
root_target = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Skip(3)->Skip(15)->Batch(5);
|
||||
EXPECT_OK(prepare_trees(root, root_target, 3));
|
||||
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>())->Batch(5);
|
||||
root_target = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>(15))->Batch(5);
|
||||
EXPECT_OK(prepare_trees(root, root_target, 3));
|
||||
}
|
||||
|
||||
/// Feature: MindData Skip Pushdown Optimization Pass Test
|
||||
/// Description: Test MindData Skip Pushdown Optimization Pass with SkipPushdownNonMappableSourceNode(CSV)
|
||||
/// Expectation: Skip node is pushed down for Rename/Batch/Project after optimization pass, but cannot be removed for
|
||||
/// NonMappableSourceNode(CSV)
|
||||
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownNonMappableSourceNode2) {
|
||||
MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownNonMappableSourceNode2.";
|
||||
std::string folder_path = datasets_root_path_ + "/testCSV/append.csv";
|
||||
std::vector<std::string> column_names = {"col1", "col2", "col3", "col4"};
|
||||
|
||||
std::shared_ptr<Dataset> root;
|
||||
std::shared_ptr<Dataset> root_target;
|
||||
|
||||
root = CSV({folder_path}, ',', {}, column_names, 0, ShuffleMode::kFalse)->Rename({"col1"}, {"fake_label"});
|
||||
root_target =
|
||||
CSV({folder_path}, ',', {}, column_names, 0, ShuffleMode::kFalse)->Skip(1)->Rename({"col1"}, {"fake_label"});
|
||||
EXPECT_OK(prepare_trees(root, root_target, 1));
|
||||
|
||||
root = CSV({folder_path}, ',', {}, column_names, 0, ShuffleMode::kFalse)->Skip(1)->Rename({"col1"}, {"fake_label"});
|
||||
root_target = CSV({folder_path}, ',', {}, column_names, 0, ShuffleMode::kFalse)
|
||||
->Skip(1)
|
||||
->Skip(1)
|
||||
->Rename({"col1"}, {"fake_label"});
|
||||
EXPECT_OK(prepare_trees(root, root_target, 1));
|
||||
|
||||
root = CSV({folder_path}, ',', {}, column_names, 0, ShuffleMode::kFalse)
|
||||
->Repeat(5)
|
||||
->Skip(1)
|
||||
->Batch(2)
|
||||
->Project({"col1", "col2"})
|
||||
->Rename({"col1"}, {"fake_label"});
|
||||
root_target = CSV({folder_path}, ',', {}, column_names, 0, ShuffleMode::kFalse)
|
||||
->Repeat(5)
|
||||
->Skip(1)
|
||||
->Skip(2)
|
||||
->Batch(2)
|
||||
->Project({"col1", "col2"})
|
||||
->Rename({"col1"}, {"fake_label"});
|
||||
EXPECT_OK(prepare_trees(root, root_target, 1));
|
||||
}
|
||||
|
||||
/// Feature: MindData Skip Pushdown Optimization Pass Test
|
||||
/// Description: Test MindData Skip Pushdown Optimization Pass with combined Operations
|
||||
/// Expectation: Skip node is removed/reduced for Rename/Batch after optimization pass
|
||||
TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownCombineOperations4) {
|
||||
MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownCombineOperations4.";
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
|
||||
std::shared_ptr<Dataset> root;
|
||||
std::shared_ptr<Dataset> root_target;
|
||||
|
||||
std::vector<int64_t> indices = {0, 1, 2, 3, 4, 5};
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SubsetRandomSampler>(indices, 6))
|
||||
->Repeat(4)
|
||||
->Skip(4)
|
||||
->Take(10)
|
||||
->Skip(2)
|
||||
->Take(2)
|
||||
->Rename({"label"}, {"fake_label"});
|
||||
auto sampler = std::make_shared<SequentialSampler>(1);
|
||||
sampler->AddChild(std::make_shared<SubsetRandomSampler>(indices, 6));
|
||||
root_target = ImageFolder(folder_path, false, sampler)
|
||||
->Repeat(4)
|
||||
->Skip(4)
|
||||
->Take(10)
|
||||
->Skip(2)
|
||||
->Take(2)
|
||||
->Skip(1)
|
||||
->Rename({"label"}, {"fake_label"});
|
||||
EXPECT_OK(prepare_trees(root, root_target, 1));
|
||||
|
||||
root = ImageFolder(folder_path, false, std::make_shared<SubsetRandomSampler>(indices, 5))
|
||||
->Repeat(8)
|
||||
->Skip(4)
|
||||
->Skip(2)
|
||||
->Take(10)
|
||||
->Rename({"label"}, {"fake_label"})
|
||||
->Batch(4);
|
||||
sampler = std::make_shared<SequentialSampler>(1);
|
||||
sampler->AddChild(std::make_shared<SubsetRandomSampler>(indices, 5));
|
||||
root_target = ImageFolder(folder_path, false, std::make_shared<SubsetRandomSampler>(indices, 5))
|
||||
->Repeat(8)
|
||||
->Skip(4)
|
||||
->Skip(2)
|
||||
->Take(10)
|
||||
->Skip(8)
|
||||
->Rename({"label"}, {"fake_label"})
|
||||
->Batch(4);
|
||||
EXPECT_OK(prepare_trees(root, root_target, 2));
|
||||
|
||||
folder_path = datasets_root_path_ + "/testCSV/append.csv";
|
||||
std::vector<std::string> column_names = {"col1", "col2", "col3", "col4"};
|
||||
|
||||
root = CSV({folder_path}, ',', {}, column_names, 0, ShuffleMode::kFalse)
|
||||
->Skip(1)
|
||||
->Repeat(4)
|
||||
->Batch(3)
|
||||
->Skip(1)
|
||||
->Rename({"col1"}, {"fake_label"});
|
||||
root_target = CSV({folder_path}, ',', {}, column_names, 0, ShuffleMode::kFalse)
|
||||
->Skip(1)
|
||||
->Repeat(4)
|
||||
->Skip(3)
|
||||
->Batch(3)
|
||||
->Rename({"col1"}, {"fake_label"});
|
||||
EXPECT_OK(prepare_trees(root, root_target, 0));
|
||||
}
|
Loading…
Reference in New Issue