forked from mindspore-Ecosystem/mindspore
port over getter pass
port over tensor_op_fusion pass; add fusion support for PreBuiltOperation (TensorOperation)
parent 6d6ed86a9d
commit 0e68575e77
@@ -227,6 +227,8 @@ Status PreBuiltOperation::ValidateParams() { return Status::OK(); }
 
 std::shared_ptr<TensorOp> PreBuiltOperation::Build() { return op_; }
 
+std::string PreBuiltOperation::Name() const { return op_ ? op_->Name() : kPreBuiltOperation; }
+
 // RandomApplyOperation
 RandomApplyOperation::RandomApplyOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms, double prob)
     : TensorOperation(true), transforms_(transforms), prob_(prob) {}
@@ -1264,72 +1264,7 @@ std::shared_ptr<TensorOp> RandomCropOperation::Build() {
 RandomCropDecodeResizeOperation::RandomCropDecodeResizeOperation(std::vector<int32_t> size, std::vector<float> scale,
                                                                  std::vector<float> ratio,
                                                                  InterpolationMode interpolation, int32_t max_attempts)
-    : TensorOperation(true),
-      size_(size),
-      scale_(scale),
-      ratio_(ratio),
-      interpolation_(interpolation),
-      max_attempts_(max_attempts) {}
-
-Status RandomCropDecodeResizeOperation::ValidateParams() {
-  // size
-  if (size_.empty() || size_.size() > 2) {
-    std::string err_msg = "RandomCropDecodeResize: size vector has incorrect size: " + std::to_string(size_.size());
-    MS_LOG(ERROR) << err_msg;
-    RETURN_STATUS_SYNTAX_ERROR(err_msg);
-  }
-  RETURN_IF_NOT_OK(ValidateVectorPositive("RandomCropDecodeResize", size_));
-  // rescale
-  if (scale_.empty() || scale_.size() != 2) {
-    std::string err_msg = "RandomCropDecodeResize: scale vector has incorrect size: " + std::to_string(scale_.size());
-    MS_LOG(ERROR) << err_msg;
-    RETURN_STATUS_SYNTAX_ERROR(err_msg);
-  }
-  if (scale_[0] < 0) {
-    std::string err_msg = "RandomCropDecodeResize: invalid scale, min scale must be greater than or equal to 0, got: " +
-                          std::to_string(scale_[0]);
-    MS_LOG(ERROR) << err_msg;
-    RETURN_STATUS_SYNTAX_ERROR(err_msg);
-  }
-  if (scale_[1] <= 0) {
-    std::string err_msg =
-      "RandomCropDecodeResize: invalid scale, max scale must be greater than 0, got: " + std::to_string(scale_[1]);
-    MS_LOG(ERROR) << err_msg;
-    RETURN_STATUS_SYNTAX_ERROR(err_msg);
-  }
-  if (scale_[0] > scale_[1]) {
-    std::string err_msg = "RandomCropDecodeResize: scale should be in (min,max) format. Got (max,min).";
-    MS_LOG(ERROR) << err_msg;
-    RETURN_STATUS_SYNTAX_ERROR(err_msg);
-  }
-  // ratio
-  if (ratio_.empty() || ratio_.size() != 2) {
-    std::string err_msg = "RandomCropDecodeResize: ratio vector has incorrect size: " + std::to_string(ratio_.size());
-    MS_LOG(ERROR) << err_msg;
-    RETURN_STATUS_SYNTAX_ERROR(err_msg);
-  }
-  for (int32_t i = 0; i < ratio_.size(); ++i) {
-    if (ratio_[i] <= 0) {
-      std::string err_msg =
-        "RandomCropDecodeResize: invalid ratio, ratio must be greater than 0, got: " + std::to_string(ratio_[i]);
-      MS_LOG(ERROR) << err_msg;
-      RETURN_STATUS_SYNTAX_ERROR(err_msg);
-    }
-  }
-  if (ratio_[0] > ratio_[1]) {
-    std::string err_msg = "RandomCropDecodeResize: ratio should be in (min,max) format. Got (max,min).";
-    MS_LOG(ERROR) << err_msg;
-    RETURN_STATUS_SYNTAX_ERROR(err_msg);
-  }
-  // max_attempts
-  if (max_attempts_ < 1) {
-    std::string err_msg =
-      "RandomCropDecodeResize: max_attempts must be greater than or equal to 1, got: " + std::to_string(max_attempts_);
-    MS_LOG(ERROR) << err_msg;
-    RETURN_STATUS_SYNTAX_ERROR(err_msg);
-  }
-  return Status::OK();
-}
+    : RandomResizedCropOperation(size, scale, ratio, interpolation, max_attempts) {}
 
 std::shared_ptr<TensorOp> RandomCropDecodeResizeOperation::Build() {
   int32_t crop_height = size_[0];
@@ -1352,6 +1287,9 @@ std::shared_ptr<TensorOp> RandomCropDecodeResizeOperation::Build() {
   return tensor_op;
 }
 
+RandomCropDecodeResizeOperation::RandomCropDecodeResizeOperation(const RandomResizedCropOperation &base)
+    : RandomResizedCropOperation(base) {}
+
 // RandomCropWithBBoxOperation
 RandomCropWithBBoxOperation::RandomCropWithBBoxOperation(std::vector<int32_t> size, std::vector<int32_t> padding,
                                                          bool pad_if_needed, std::vector<uint8_t> fill_value,
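The two hunks above collapse RandomCropDecodeResizeOperation into a thin subclass: its constructor now delegates to RandomResizedCropOperation, which owns the parameters and the single copy of ValidateParams, and the new conversion constructor builds the fused node straight from an existing crop-and-resize node. A minimal sketch of that shape, using abridged stand-in classes rather than the real MindSpore declarations:

#include <cstdint>
#include <utility>
#include <vector>

// Stand-in for RandomResizedCropOperation: owns the parameters and the
// single copy of the validation logic.
class ResizedCropBase {
 public:
  ResizedCropBase(std::vector<int32_t> size, std::vector<float> scale)
      : size_(std::move(size)), scale_(std::move(scale)) {}
  ResizedCropBase(const ResizedCropBase &) = default;

 protected:  // protected, so the subclass can still read the parameters
  std::vector<int32_t> size_;
  std::vector<float> scale_;
};

// Stand-in for RandomCropDecodeResizeOperation after the refactor.
class CropDecodeResize : public ResizedCropBase {
 public:
  // Delegating constructor: no duplicated member init or validation.
  CropDecodeResize(std::vector<int32_t> size, std::vector<float> scale)
      : ResizedCropBase(std::move(size), std::move(scale)) {}
  // Conversion constructor: how the fusion pass turns a crop-and-resize
  // node into the fused decode+crop+resize node.
  explicit CropDecodeResize(const ResizedCropBase &base) : ResizedCropBase(base) {}
};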
@@ -1574,62 +1512,56 @@ RandomResizedCropOperation::RandomResizedCropOperation(std::vector<int32_t> size
 Status RandomResizedCropOperation::ValidateParams() {
   // size
   if (size_.size() != 2 && size_.size() != 1) {
-    std::string err_msg =
-      "RandomResizedCrop: size must be a vector of one or two values, got: " + std::to_string(size_.size());
+    std::string err_msg = Name() + ": size must be a vector of one or two values, got: " + std::to_string(size_.size());
     MS_LOG(ERROR) << err_msg;
     RETURN_STATUS_SYNTAX_ERROR(err_msg);
   }
   if (size_[0] <= 0 || (size_.size() == 2 && size_[1] <= 0)) {
-    std::string err_msg = "RandomResizedCrop: size must only contain positive integers.";
-    MS_LOG(ERROR) << "RandomResizedCrop: size must only contain positive integers, got: " << size_;
+    std::string err_msg = Name() + ": size must only contain positive integers.";
+    MS_LOG(ERROR) << Name() + ": size must only contain positive integers, got: " << size_;
     RETURN_STATUS_SYNTAX_ERROR(err_msg);
   }
   // scale
   if (scale_.size() != 2) {
-    std::string err_msg =
-      "RandomResizedCrop: scale must be a vector of two values, got: " + std::to_string(scale_.size());
+    std::string err_msg = Name() + ": scale must be a vector of two values, got: " + std::to_string(scale_.size());
     MS_LOG(ERROR) << err_msg;
     RETURN_STATUS_SYNTAX_ERROR(err_msg);
   }
   if (scale_[0] < 0) {
-    std::string err_msg = "RandomResizedCrop: min scale must be greater than or equal to 0.";
-    MS_LOG(ERROR) << "RandomResizedCrop: min scale must be greater than or equal to 0, got: " +
-                       std::to_string(scale_[0]);
+    std::string err_msg = Name() + ": min scale must be greater than or equal to 0.";
+    MS_LOG(ERROR) << Name() + ": min scale must be greater than or equal to 0, got: " + std::to_string(scale_[0]);
     RETURN_STATUS_SYNTAX_ERROR(err_msg);
   }
   if (scale_[1] <= 0) {
-    std::string err_msg = "RandomResizedCrop: max scale must be greater than 0.";
-    MS_LOG(ERROR) << "RandomResizedCrop: max scale must be greater than 0, got: " + std::to_string(scale_[1]);
+    std::string err_msg = Name() + ": max scale must be greater than 0.";
+    MS_LOG(ERROR) << Name() + ": max scale must be greater than 0, got: " + std::to_string(scale_[1]);
     RETURN_STATUS_SYNTAX_ERROR(err_msg);
   }
   if (scale_[1] < scale_[0]) {
-    std::string err_msg = "RandomResizedCrop: scale must have a size of two in the format of (min, max).";
-    MS_LOG(ERROR) << "RandomResizedCrop: scale must have a size of two in the format of (min, max), but got: "
-                  << scale_;
+    std::string err_msg = Name() + ": scale must have a size of two in the format of (min, max).";
+    MS_LOG(ERROR) << Name() + ": scale must have a size of two in the format of (min, max), but got: " << scale_;
     RETURN_STATUS_SYNTAX_ERROR(err_msg);
   }
   // ratio
   if (ratio_.size() != 2) {
-    std::string err_msg =
-      "RandomResizedCrop: ratio must be a vector of two values, got: " + std::to_string(ratio_.size());
+    std::string err_msg = Name() + ": ratio must be a vector of two values, got: " + std::to_string(ratio_.size());
     MS_LOG(ERROR) << err_msg;
     RETURN_STATUS_SYNTAX_ERROR(err_msg);
   }
   if (ratio_[0] <= 0 || ratio_[1] <= 0) {
-    std::string err_msg = "RandomResizedCrop: ratio must be greater than 0.";
-    MS_LOG(ERROR) << "RandomResizedCrop: ratio must be greater than 0, got: " << ratio_;
+    std::string err_msg = Name() + ": ratio must be greater than 0.";
+    MS_LOG(ERROR) << Name() + ": ratio must be greater than 0, got: " << ratio_;
     RETURN_STATUS_SYNTAX_ERROR(err_msg);
   }
   if (ratio_[1] < ratio_[0]) {
-    std::string err_msg = "RandomResizedCrop: ratio must have a size of two in the format of (min, max).";
-    MS_LOG(ERROR) << "RandomResizedCrop: ratio must have a size of two in the format of (min, max), but got: "
-                  << ratio_;
+    std::string err_msg = Name() + ": ratio must have a size of two in the format of (min, max).";
+    MS_LOG(ERROR) << Name() + ": ratio must have a size of two in the format of (min, max), but got: " << ratio_;
     RETURN_STATUS_SYNTAX_ERROR(err_msg);
   }
   // max_attempts
   if (max_attempts_ < 1) {
     std::string err_msg =
-      "RandomResizedCrop: max_attempts must be greater than or equal to 1, got: " + std::to_string(max_attempts_);
+      Name() + ": max_attempts must be greater than or equal to 1, got: " + std::to_string(max_attempts_);
     MS_LOG(ERROR) << err_msg;
     RETURN_STATUS_SYNTAX_ERROR(err_msg);
   }
@@ -515,17 +515,6 @@ Status TreeGetters::GetClassIndexing(std::vector<std::pair<std::string, std::vec
   return Status::OK();
 }
 
-Status TreeGetters::InternalInit(int8_t type) {
-  if (init_flag_) return Status::OK();
-  tree_adapter_->SetPrePassOverride([&type](OptPass pre) {
-    pre.push_back(std::make_unique<GetterPass>(static_cast<GetterPass::GetterType>(type)));
-    return pre;
-  });
-  Status s = tree_adapter_->Compile(std::move(root_), 1);
-  if (s.IsOk()) init_flag_ = true;
-  return s;
-}
-
 Status TreeGetters::InternalInit() {
   if (init_flag_) return Status::OK();
   Status s = tree_adapter_->Compile(std::move(root_), 1);
@@ -535,7 +524,7 @@ Status TreeGetters::InternalInit() {
 
 Status TreeGetters::GetFirstRowShapeAndType() {
   RETURN_OK_IF_TRUE(first_row_obtained_);
-  RETURN_IF_NOT_OK(InternalInit(static_cast<int8_t>(GetterPass::kOutputShapeAndType)));
+  RETURN_IF_NOT_OK(InternalInit());
   TensorRow first_row;
   RETURN_IF_NOT_OK(GetRow(&first_row));
   std::transform(first_row.begin(), first_row.end(), std::back_inserter(first_row_type_),
@@ -572,11 +561,6 @@ Status DatasetSizeGetter::Init(std::shared_ptr<DatasetNode> d) {
 Status DatasetSizeGetter::DryRun(std::shared_ptr<DatasetNode> ir_node, int64_t *dataset_size) {
   std::shared_ptr<TreeAdapter> tree_adapter = std::make_shared<TreeAdapter>(TreeAdapter::UsageFlag::kDeGetter);
   tree_adapters_.push_back(tree_adapter);
-  tree_adapter->SetPrePassOverride([](OptPass pre) {
-    pre.push_back(
-      std::make_unique<GetterPass>(static_cast<GetterPass::GetterType>(GetterPass::GetterType::kDatasetSize)));
-    return pre;
-  });
   RETURN_IF_NOT_OK(tree_adapter->Compile(ir_node, 1));
   TensorRow row;
   RETURN_IF_NOT_OK(GetRow(tree_adapter, &row));
@@ -199,7 +199,6 @@ class TreeGetters : public TreeConsumer {
   bool first_row_obtained_;  // whether first row (which could be empty) is obtained by TreeGetter
   bool init_flag_;           // indicate whether the tree has initialized
 
-  Status InternalInit(int8_t type);
   Status InternalInit();
 };
 
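With the typed InternalInit(int8_t) gone, something else has to put GetterPass into the pipeline; the TreeAdapter hunks further down register it during PrePass whenever the adapter was built with UsageFlag::kDeGetter. A small sketch of that registration pattern, with stub pass types and a hypothetical BuildPrePass helper; only the GetterPass-on-kDeGetter behavior comes from this commit:

#include <memory>
#include <vector>

struct Pass { virtual ~Pass() = default; };
struct InputValidationPass : Pass {};
struct GetterPass : Pass {};

enum class UsageFlag { kDeIterator, kDeGetter };

// The adapter already knows which consumer it serves, so it appends the
// getter-only pass itself instead of letting consumers override the list.
std::vector<std::unique_ptr<Pass>> BuildPrePass(UsageFlag usage) {
  std::vector<std::unique_ptr<Pass>> actions;
  actions.emplace_back(std::make_unique<InputValidationPass>());
  if (usage == UsageFlag::kDeGetter) actions.emplace_back(std::make_unique<GetterPass>());
  return actions;
}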
@@ -40,12 +40,9 @@
 namespace mindspore {
 namespace dataset {
 // Constructor
-ExecutionTree::ExecutionTree() : id_count_(0), pre_pass_override_(nullptr) {
+ExecutionTree::ExecutionTree() : id_count_(0), tree_state_(kDeTStateInit), prepare_flags_(kDePrepNone) {
   tg_ = std::make_unique<TaskGroup>();
-  tree_state_ = kDeTStateInit;
-  prepare_flags_ = kDePrepNone;
   profiling_manager_ = std::make_unique<ProfilingManager>(this);
-  optimize_ = common::GetEnv("OPTIMIZE") == "true" ? true : false;
 #if defined(NUMA_ENABLED) && (defined(ENABLE_GPUQUE) || defined(ENABLE_TDTQUE))
   std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
   rank_id_ = cfg->rank_id();
@@ -275,10 +272,6 @@ Status ExecutionTree::Prepare(int32_t num_epochs, bool partial) {
   // Pre optimization compulsory transformation
   RETURN_IF_NOT_OK(this->PreAction());
 
-  // If optional optimizations are enabled
-  if (optimize_) {
-    RETURN_IF_NOT_OK(this->Optimize());
-  }
   // Post optimization compulsory transformation
   RETURN_IF_NOT_OK(this->PostAction());
 
@@ -302,14 +295,6 @@ Status ExecutionTree::PreAction() {
     pre_actions.push_back(std::make_unique<RemovalPass>());
   }
 
-  // this offers a way to override the preset optimization pass with customized ones
-  // this is used when certain nodes are removed for tree getters
-  if (pre_pass_override_) {
-    MS_LOG(INFO) << "Default pre optimization passes is being overridden,"
-                 << " number of passes before the override:" << pre_actions.size() << ".";
-    pre_actions = pre_pass_override_(std::move(pre_actions));
-  }
-
   MS_LOG(INFO) << "Running " << pre_actions.size() << " pre pass loops.";
 
   // Apply pre action passes
@@ -343,22 +328,6 @@ Status ExecutionTree::PostAction() {
   return Status::OK();
 }
 
-Status ExecutionTree::Optimize() {
-  // Vector of optimizations, currently only 1, add more as necessary
-  OptPass optimizations;
-#ifndef ENABLE_ANDROID
-  optimizations.push_back(std::make_unique<TensorOpFusionPass>());
-#endif
-  // vector of flags for each optimization
-  std::vector<bool> modified(optimizations.size(), false);
-  for (auto i = 0; i < optimizations.size(); i++) {
-    auto m = false;
-    optimizations[i]->Run(this, &m);
-    modified[i] = m;
-  }
-  return Status::OK();
-}
-
 // The driver of the prepare phase of the execution tree. The prepare phase will recursively
 // walk the tree to perform modifications to the tree or specific nodes within the tree to get
 // it ready for execution.
@@ -192,10 +192,6 @@ class ExecutionTree {
   // @return Status The status code returned
   Status PostAction();
 
-  // Optimization transformation/action, optional.
-  // @return Status The status code returned
-  Status Optimize();
-
   // The DEPRECATED driver of the prepare phase of the execution tree. The prepare phase will recursively
   // walk the tree to perform modifications to the tree or specific nodes within the tree to get
   // it ready for execution.
@@ -240,29 +236,10 @@ class ExecutionTree {
   // Getter for profiling manager, no ownership
   ProfilingManager *GetProfilingManager() { return profiling_manager_.get(); }
 
-  // Set optional optimization if tree has not been prepared yet
-  Status SetOptimize(bool value) {
-    if (tree_state_ != kDeTStateInit && tree_state_ != kDeTStateBuilding) {
-      std::string optimize = (optimize_ == true) ? "true" : "false";
-      std::string msg = "Tree has already been prepared with OPTIMIZE set to " + optimize;
-      RETURN_STATUS_UNEXPECTED(msg);
-    } else {
-      optimize_ = value;
-      return Status::OK();
-    }
-  }
-
-  // Optional optimizations status
-  bool OptimizationEnabled() const { return optimize_; }
-
   // Getter function to get the total number of epochs to be run on this tree.
   // @return total number of epochs
   int32_t num_epochs() { return num_epochs_; }
 
-  // set the function ptr that overrides the pre-pass which allows caller to adjust the existing pre_pass and
-  // introduce new passes. E.g. caller can override the num_epoch in EpochInjectionPass
-  void SetPrePassOverride(std::function<OptPass(OptPass)> pre_pass_override) { pre_pass_override_ = pre_pass_override; }
-
  private:
   // A helper functions for doing the recursive printing
   // @param dataset_op - The dataset op to print
@@ -279,8 +256,6 @@ class ExecutionTree {
   TreeState tree_state_;                                 // Tracking the current tree state
   int32_t num_epochs_;                                   // Total number of epochs to run for this tree
   std::unique_ptr<ProfilingManager> profiling_manager_;  // Profiling manager
-  bool optimize_;                                        // Flag to enable optional optimizations
-  std::function<OptPass(OptPass)> pre_pass_override_;    // function ptr that overrides pre pass, called in PrePrepare()
   bool partially_prepare_;  // Temp: during migration to IR, if true, run remaining passes.
 #if defined(NUMA_ENABLED) && (defined(ENABLE_GPUQUE) || defined(ENABLE_TDTQUE))
   // This rank_id is for numa and device_queue, one process work with only one rank_id,
@@ -115,5 +115,10 @@ Status MapNode::AcceptAfter(IRNodePass *p, bool *modified) {
   // Downcast shared pointer then call visitor
   return p->VisitAfter(shared_from_base<MapNode>(), modified);
 }
+
+void MapNode::setOperations(const std::vector<std::shared_ptr<TensorOperation>> &operations) {
+  operations_ = operations;
+}
+
+std::vector<std::shared_ptr<TensorOperation>> MapNode::operations() { return operations_; }
 }  // namespace dataset
 }  // namespace mindspore
@@ -75,8 +75,19 @@ class MapNode : public DatasetNode {
   /// \return Status of the node visit
   Status AcceptAfter(IRNodePass *p, bool *modified) override;
 
+  /// \brief clear all callbacks
+  void ClearCallbacks() { callbacks_.clear(); }
+
+  /// \brief getter to get all tensor operations
+  std::vector<std::shared_ptr<TensorOperation>> operations();
+
+  /// \brief setter to set all tensor operations
+  void setOperations(const std::vector<std::shared_ptr<TensorOperation>> &operations);
+
  private:
   std::vector<std::shared_ptr<TensorOperation>> operations_;
 
+ private:
   std::vector<std::string> input_columns_;
   std::vector<std::string> output_columns_;
   std::vector<std::string> project_columns_;
@@ -13,45 +13,53 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#include <algorithm>
 #include <memory>
+#include <string>
+#include <vector>
+
+#include "minddata/dataset/engine/ir/datasetops/map_node.h"
 #include "minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h"
-#include "minddata/dataset/kernels/image/decode_op.h"
-#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
+#include "minddata/dataset/include/transforms.h"
+#include "minddata/dataset/include/vision.h"
+#include "minddata/dataset/include/vision_lite.h"
+#include "minddata/dataset/kernels/image/random_crop_and_resize_op.h"
 #include "minddata/dataset/kernels/image/random_crop_decode_resize_op.h"
 
 namespace mindspore {
 namespace dataset {
 
-Status TensorOpFusionPass::RunOnNode(std::shared_ptr<MapOp> node, bool *modified) {
-  // Most primitive pattern: DecodeOp immediately followed by RandomCropAndResizeOp
-  // Abstract into a more general member function that can find any pattern, expressed
-  // by regular expressions, for instance.
-  // Add a list of optimisation policies. For now, just this lambda
-  auto FindPattern = [](auto &tfuncs) {
-    auto it =
-      std::find_if(tfuncs.begin(), tfuncs.end(), [](const auto &tf) -> bool { return tf->Name() == kDecodeOp; });
-    auto next = it + 1;
-    if (it != tfuncs.end() && next != tfuncs.end() && (*next)->Name() == kRandomCropAndResizeOp) {
-      return it;
-    } else {
-      return tfuncs.end();
-    }
-  };
-
-  auto &tfuncs = node->TFuncs();
-  auto it = FindPattern(tfuncs);
-  if (it != tfuncs.end()) {
-    auto next = it + 1;
-    auto op = static_cast<RandomCropAndResizeOp *>(next->get());
-    *it = std::static_pointer_cast<TensorOp>(std::make_shared<RandomCropDecodeResizeOp>(*op));
-    tfuncs.erase(next);
-  }
-  if (modified != nullptr) {
-    *modified = true;
-  } else {
-    RETURN_STATUS_UNEXPECTED("modified is nullptr");
-  }
+Status TensorOpFusionPass::Visit(std::shared_ptr<MapNode> node, bool *modified) {
+  std::vector<std::shared_ptr<TensorOperation>> ops = node->operations();
+
+  // start temporary code, to deal with pre-built TensorOperation
+  std::vector<std::string> pattern = {kDecodeOp, kRandomCropAndResizeOp};
+  auto itr = std::search(ops.begin(), ops.end(), pattern.begin(), pattern.end(),
+                         [](auto op, const std::string &nm) { return op->Name() == nm; });
+  if (itr != ops.end()) {
+    MS_LOG(WARNING) << "Fusing pre-build Decode and RandomCropResize into one pre-build.";
+    auto op = dynamic_cast<RandomCropAndResizeOp *>((*(itr + 1))->Build().get());
+    (*itr) = std::make_shared<transforms::PreBuiltOperation>(std::make_shared<RandomCropDecodeResizeOp>(*op));
+    ops.erase(itr + 1);
+    node->setOperations(ops);
+    *modified = true;
+    return Status::OK();
+  }  // end of temporary code, needs to be deleted when tensorOperation's pybind completes
+
+  // logic below is for non-prebuilt TensorOperation
+  pattern = {vision::kDecodeOperation, vision::kRandomResizedCropOperation};
+  itr = std::search(ops.begin(), ops.end(), pattern.begin(), pattern.end(),
+                    [](auto op, const std::string &nm) { return op->Name() == nm; });
+
+  // return here if no pattern is found
+  RETURN_OK_IF_TRUE(itr == ops.end());
+  auto *op = dynamic_cast<vision::RandomResizedCropOperation *>((itr + 1)->get());
+  RETURN_UNEXPECTED_IF_NULL(op);
+  // fuse the two ops
+  (*itr) = std::make_shared<vision::RandomCropDecodeResizeOperation>(*op);
+  ops.erase(itr + 1);
+  node->setOperations(ops);
+  *modified = true;
   return Status::OK();
 }
 }  // namespace dataset
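The rewritten pass replaces the hand-rolled find_if lambda with std::search over operation names, which is the piece most likely to be reused for other fusion patterns. A self-contained sketch of that mechanism, with a toy Op type and a hypothetical FuseAdjacent helper; the std::search call mirrors the one in TensorOpFusionPass::Visit above:

#include <algorithm>
#include <functional>
#include <memory>
#include <string>
#include <vector>

struct Op {
  explicit Op(std::string n) : name(std::move(n)) {}
  std::string name;
};

// Replace the first adjacent [first, second] pair with the fused op;
// returns true if a rewrite happened, mirroring the *modified flag.
bool FuseAdjacent(std::vector<std::shared_ptr<Op>> *ops, const std::string &first, const std::string &second,
                  const std::function<std::shared_ptr<Op>()> &make_fused) {
  const std::vector<std::string> pattern = {first, second};
  auto itr = std::search(ops->begin(), ops->end(), pattern.begin(), pattern.end(),
                         [](const std::shared_ptr<Op> &op, const std::string &nm) { return op->name == nm; });
  if (itr == ops->end()) return false;
  *itr = make_fused();  // overwrite the first op of the pair with the fused op
  ops->erase(itr + 1);  // drop the now-redundant second op
  return true;
}

A call like FuseAdjacent(&ops, "Decode", "RandomCropAndResize", [] { return std::make_shared<Op>("RandomCropDecodeResize"); }) performs the same rewrite the pass applies to MapNode's operation list.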
@@ -25,12 +25,12 @@ namespace dataset {
 /// \class TensorOpFusionPass tensor_op_fusion_pass.h
 /// \brief And optional optimization pass identifying and fusing
 ///     tensor ops within MapOp
-class TensorOpFusionPass : public NodePass {
+class TensorOpFusionPass : public IRNodePass {
   /// \brief Identifies and fuses tensor ops within MapOp
   /// \param[in] node The node being visited
   /// \param[inout] *modified indicates whether the node has been visited
   /// \return Status The status code returned
-  Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified) override;
+  Status Visit(std::shared_ptr<MapNode> node, bool *modified) override;
 };
 }  // namespace dataset
 }  // namespace mindspore
@@ -15,52 +15,13 @@
  */
 
 #include "minddata/dataset/engine/opt/pre/getter_pass.h"
-#include "minddata/dataset/engine/execution_tree.h"
+#include "minddata/dataset/engine/ir/datasetops/map_node.h"
 
 namespace mindspore {
 namespace dataset {
-Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) {
-  nodes_to_remove_.push_back(node);
-  return Status::OK();
-}
-
-Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
-  if (type_ == kOutputShapeAndType) nodes_to_remove_.push_back(node);
-  return Status::OK();
-}
-
-Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) {
-  if (type_ == kOutputShapeAndType) nodes_to_remove_.push_back(node);
-  return Status::OK();
-}
-
-Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) {
-  if (type_ == kOutputShapeAndType) nodes_to_remove_.push_back(node);
-  return Status::OK();
-}
-
-Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<MapOp> node, bool *modified) {
-  nodes_to_clear_callback_.push_back(node);
-  return Status::OK();
-}
-
-#ifdef ENABLE_PYTHON
-Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) {
-  if (type_ == kOutputShapeAndType) nodes_to_remove_.push_back(node);
-  return Status::OK();
-}
-#endif
-
-Status GetterPass::RunOnTree(ExecutionTree *tree, bool *modified) {
-  RETURN_IF_NOT_OK(pass_.Run(tree, modified));
-
-  // currently the getter pass only disables call_back from the execution tree
-  // clear the callback for selected ops (map when its GetOutputType/Shape)
-  for (auto node : pass_.nodes_to_clear_callback_) node->ClearCallbacks();
-
-  return Status::OK();
-}
-
+Status GetterPass::Visit(std::shared_ptr<MapNode> node, bool *modified) {
+  node->ClearCallbacks();
+  return Status::OK();
+}
 }  // namespace dataset
 }  // namespace mindspore
@@ -19,7 +19,6 @@
 
 #include <memory>
 #include <list>
-#include "minddata/dataset/engine/datasetops/dataset_op.h"
 #include "minddata/dataset/engine/opt/pass.h"
 
 namespace mindspore {
@@ -28,48 +27,16 @@ namespace dataset {
 class DatasetOp;
 
 /// \class GetterPass
-/// \brief This is a tree pass that will remove nodes or clears the callback in MapOp
-class GetterPass : public TreePass {
+/// \brief This is a tree pass that will for now only clear the callback in MapOp to prevent hang
+class GetterPass : public IRNodePass {
  public:
-  enum GetterType { kDatasetSize = 1, kOutputShapeAndType = 2 };
-  /// \brief Constructor
-  explicit GetterPass(GetterType tp) : pass_(tp) {}
-
-  /// \brief default copy Constructor
-  explicit GetterPass(const GetterPass &) = default;
-
-  /// \brief Destructor
+  /// \brief Default Constructor
+  GetterPass() = default;
+
+  /// \brief Default Destructor
   ~GetterPass() = default;
 
-  Status RunOnTree(ExecutionTree *tree, bool *modified) override;
-
- private:
-  /// \class GetterNodes, this is a nested class which is owned via composition by the outter class to identify nodes
-  /// \brief This is a NodePass who's job is to identify which nodes should be removed.
-  class GetterNodes : public NodePass {
-   public:
-    /// \brief Constructor
-    explicit GetterNodes(GetterType tp) : type_(tp) {}
-
-    ~GetterNodes() = default;
-
-    Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override;
-    Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override;
-    Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) override { return Status::OK(); }
-    Status RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) override;
-    Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) override;
-    Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified) override;
-
-#ifdef ENABLE_PYTHON
-    Status RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) override;
-#endif
-
-    GetterType type_;
-    std::list<std::shared_ptr<DatasetOp>> nodes_to_clear_callback_;
-    std::list<std::shared_ptr<DatasetOp>> nodes_to_remove_;
-  };
-  // outer class needs only to own the inner class object since it automatically has access to its private variables
-  GetterNodes pass_;
+  Status Visit(std::shared_ptr<MapNode> node, bool *modified) override;
 };
 }  // namespace dataset
 }  // namespace mindspore
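The header shrinks because an IRNodePass only needs overrides for the node types it touches: no tree-level pass, no nested node collector, no getter-type enum. A simplified sketch of the visitor shape this relies on, with hypothetical IRVisitor/IRNode stubs rather than the minddata API:

struct MapNode;

struct IRVisitor {
  virtual ~IRVisitor() = default;
  virtual void Visit(MapNode *node, bool *modified) {}  // default: do nothing
};

struct IRNode {
  virtual ~IRNode() = default;
  virtual void Accept(IRVisitor *v, bool *modified) {}
};

struct MapNode : IRNode {
  int cleared = 0;
  void ClearCallbacks() { ++cleared; }
  // Double dispatch: the node selects the Visit overload for its own type.
  void Accept(IRVisitor *v, bool *modified) override { v->Visit(this, modified); }
};

// GetterPass in this shape: one override, everything else inherited.
struct GetterPassSketch : IRVisitor {
  void Visit(MapNode *node, bool *modified) override {
    node->ClearCallbacks();  // strip user callbacks so getters cannot hang
    *modified = true;
  }
};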
@@ -27,7 +27,7 @@ namespace dataset {
 /// \class InputValidationPass
 /// \brief This is a parse pass that validates input parameters of the IR tree.
 class InputValidationPass : public IRNodePass {
-  /// \brief Runs a validatation pass to check input parameters
+  /// \brief Runs a validation pass to check input parameters
   /// \param[in] node The node being visited
   /// \param[inout] *modified indicates whether the node has been visited
   /// \return Status code
@@ -18,11 +18,13 @@
 
 #include "minddata/dataset/core/client.h"
 #include "minddata/dataset/engine/ir/datasetops/root_node.h"
+#include "minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h"
 #include "minddata/dataset/engine/opt/pass.h"
 #include "minddata/dataset/engine/opt/post/auto_worker_pass.h"
 #include "minddata/dataset/engine/opt/pre/cache_validation_pass.h"
 #include "minddata/dataset/engine/opt/pre/deep_copy_pass.h"
 #include "minddata/dataset/engine/opt/pre/epoch_ctrl_pass.h"
+#include "minddata/dataset/engine/opt/pre/getter_pass.h"
 #include "minddata/dataset/engine/opt/pre/input_validation_pass.h"
 #include "minddata/dataset/engine/opt/pre/node_removal_pass.h"
 
@@ -38,11 +40,11 @@ Status TreeAdapter::PrePass(std::shared_ptr<DatasetNode> ir) {
   std::vector<std::unique_ptr<IRPass>> actions;
 
   MS_LOG(INFO) << "Running pre pass loops.";
-  actions.push_back(std::make_unique<InputValidationPass>());
-  actions.push_back(std::make_unique<CacheValidationPass>());
-  actions.push_back(std::make_unique<NodeRemovalPass>());
-  actions.push_back(std::make_unique<EpochCtrlPass>());
+  actions.emplace_back(std::make_unique<InputValidationPass>());
+  actions.emplace_back(std::make_unique<CacheValidationPass>());
+  actions.emplace_back(std::make_unique<NodeRemovalPass>());
+  actions.emplace_back(std::make_unique<EpochCtrlPass>());
+  if (usage_ == kDeGetter) actions.emplace_back(std::make_unique<GetterPass>());
   // Vector of flags for each action
   std::vector<bool> modified(actions.size(), false);
   // Apply pre-pass actions
@@ -59,16 +61,11 @@ Status TreeAdapter::Optimize(std::shared_ptr<DatasetNode> ir) {
   // Vector of optimizations
   std::vector<std::unique_ptr<IRNodePass>> optimizations;
   MS_LOG(INFO) << "Running optimization pass loops";
-  // We will gradually move TensorOpFusionPass from ExecutionTree::Optimize to here.
-
-  // Vector of flags for each optimization
-  std::vector<bool> modified(optimizations.size(), false);
+  optimizations.emplace_back(std::make_unique<TensorOpFusionPass>());
+
   // Apply optimization pass actions
   for (auto i = 0; i < optimizations.size(); i++) {
-    auto m = false;
-    RETURN_IF_NOT_OK(optimizations[i]->Run(ir, &m));
-    modified[i] = m;
+    bool modified = false;
+    RETURN_IF_NOT_OK(optimizations[i]->Run(ir, &modified));
   }
   MS_LOG(INFO) << "Optimization pass complete.";
   return Status::OK();
@@ -133,8 +130,6 @@ Status TreeAdapter::Build(std::shared_ptr<DatasetNode> root_ir, int32_t num_epoc
   RETURN_IF_NOT_OK(BuildExecutionTreeRecur(root_ir->Children()[0], &root_op));
   RETURN_IF_NOT_OK(tree_->AssignRoot(root_op));
 
-  if (pre_pass_override_) tree_->SetPrePassOverride(pre_pass_override_);
-
   // Note: We will gradually move the pre pass, optimizer pass, and post pass
   // on ExecutionTree to perform on IR tree.
   // Prepare the tree
@@ -66,9 +66,6 @@ class TreeAdapter {
   // Set optional optimization pass
   void SetOptimize(bool value) { optimize_ = value; }
 
-  // function to override override the pre-pass
-  void SetPrePassOverride(std::function<OptPass(OptPass)> pre_pass_override) { pre_pass_override_ = pre_pass_override; }
-
   // Optional optimizations status
   bool OptimizationEnabled() const { return optimize_; }
 
@@ -90,14 +87,13 @@ class TreeAdapter {
 
   std::unique_ptr<DataBuffer> cur_db_;
   std::unordered_map<std::string, int32_t> column_name_map_;
   std::unique_ptr<ExecutionTree> tree_;              // current connector capacity of root op, used for profiling
   bool optimize_;                                    // Flag to enable optional optimization pass
   std::shared_ptr<DatasetIteratorTracing> tracing_;  // trace profiling data
   int32_t cur_batch_num_;                            // current batch number, used for profiling
   int32_t cur_connector_size_;                       // current connector size of root op, used for profiling
   int32_t cur_connector_capacity_;                   // current connector capacity of root op, used for profiling
-  std::function<OptPass(OptPass)> pre_pass_override_;  // function ptr that overrides pre pass, called in PrePrepare()
-  UsageFlag usage_;  // usage of this tree adapter (type of consumer)
+  UsageFlag usage_;                                  // usage of this tree adapter (type of consumer)
 
   // State flags for the lifecycle of the tree
   enum CompileState {
     kCompileStateInit = 0,  // The freshly initialized state
@@ -204,7 +204,7 @@ class PreBuiltOperation : public TensorOperation {
 
   Status ValidateParams() override;
 
-  std::string Name() const override { return kPreBuiltOperation; }
+  std::string Name() const override;
 
  private:
   std::shared_ptr<TensorOp> op_;
@@ -759,20 +759,25 @@ class RandomCropOperation : public TensorOperation {
   BorderType padding_mode_;
 };
 
-class RandomCropDecodeResizeOperation : public TensorOperation {
+class RandomResizedCropOperation : public TensorOperation {
  public:
-  RandomCropDecodeResizeOperation(std::vector<int32_t> size, std::vector<float> scale, std::vector<float> ratio,
-                                  InterpolationMode interpolation, int32_t max_attempts);
+  RandomResizedCropOperation(std::vector<int32_t> size, std::vector<float> scale = {0.08, 1.0},
+                             std::vector<float> ratio = {3. / 4., 4. / 3.},
+                             InterpolationMode interpolation = InterpolationMode::kNearestNeighbour,
+                             int32_t max_attempts = 10);
 
-  ~RandomCropDecodeResizeOperation() = default;
+  /// \brief default copy constructor
+  explicit RandomResizedCropOperation(const RandomResizedCropOperation &) = default;
+
+  ~RandomResizedCropOperation() = default;
 
   std::shared_ptr<TensorOp> Build() override;
 
   Status ValidateParams() override;
 
-  std::string Name() const override { return kRandomCropDecodeResizeOperation; }
+  std::string Name() const override { return kRandomResizedCropOperation; }
 
- private:
+ protected:
   std::vector<int32_t> size_;
   std::vector<float> scale_;
   std::vector<float> ratio_;
@@ -780,6 +785,20 @@ class RandomCropDecodeResizeOperation : public TensorOperation {
   int32_t max_attempts_;
 };
 
+class RandomCropDecodeResizeOperation : public RandomResizedCropOperation {
+ public:
+  RandomCropDecodeResizeOperation(std::vector<int32_t> size, std::vector<float> scale, std::vector<float> ratio,
+                                  InterpolationMode interpolation, int32_t max_attempts);
+
+  explicit RandomCropDecodeResizeOperation(const RandomResizedCropOperation &base);
+
+  ~RandomCropDecodeResizeOperation() = default;
+
+  std::shared_ptr<TensorOp> Build() override;
+
+  std::string Name() const override { return kRandomCropDecodeResizeOperation; }
+};
+
 class RandomCropWithBBoxOperation : public TensorOperation {
  public:
   RandomCropWithBBoxOperation(std::vector<int32_t> size, std::vector<int32_t> padding = {0, 0, 0, 0},
@@ -882,29 +901,6 @@ class RandomResizeWithBBoxOperation : public TensorOperation {
   std::vector<int32_t> size_;
 };
 
-class RandomResizedCropOperation : public TensorOperation {
- public:
-  explicit RandomResizedCropOperation(std::vector<int32_t> size, std::vector<float> scale = {0.08, 1.0},
-                                      std::vector<float> ratio = {3. / 4., 4. / 3.},
-                                      InterpolationMode interpolation = InterpolationMode::kNearestNeighbour,
-                                      int32_t max_attempts = 10);
-
-  ~RandomResizedCropOperation() = default;
-
-  std::shared_ptr<TensorOp> Build() override;
-
-  Status ValidateParams() override;
-
-  std::string Name() const override { return kRandomResizedCropOperation; }
-
- private:
-  std::vector<int32_t> size_;
-  std::vector<float> scale_;
-  std::vector<float> ratio_;
-  InterpolationMode interpolation_;
-  int32_t max_attempts_;
-};
-
 class RandomResizedCropWithBBoxOperation : public TensorOperation {
  public:
   explicit RandomResizedCropWithBBoxOperation(std::vector<int32_t> size, std::vector<float> scale = {0.08, 1.0},
@@ -16,14 +16,17 @@
 
 #include <memory>
 #include <string>
-#include "minddata/dataset/core/client.h"
 #include "common/common.h"
 #include "gtest/gtest.h"
-#include "minddata/dataset/engine/execution_tree.h"
+#include "minddata/dataset/core/client.h"
 #include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
+#include "minddata/dataset/engine/ir/datasetops/map_node.h"
+#include "minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h"
 #include "minddata/dataset/engine/opt/post/auto_worker_pass.h"
-#include "minddata/dataset/engine/opt/pre/getter_pass.h"
+#include "minddata/dataset/include/transforms.h"
+#include "minddata/dataset/include/vision.h"
+#include "minddata/dataset/include/vision_lite.h"
 
 using namespace mindspore::dataset;
 using mindspore::LogStream;
@@ -31,7 +34,6 @@ using mindspore::MsLogLevel::INFO;
 
 class MindDataTestOptimizationPass : public UT::DatasetOpTesting {};
 
-
 TEST_F(MindDataTestOptimizationPass, MindDataTestAutoWorkerPass) {
   MS_LOG(INFO) << "Doing MindDataTestOptimizationPass-MindDataTestAutoWorkerPass.";
 
@@ -63,3 +65,41 @@ TEST_F(MindDataTestOptimizationPass, MindDataTestAutoWorkerPass) {
   MS_LOG(DEBUG) << batch->IRNode()->Name() << ": num_worker=" << batch->IRNode()->num_workers();
   MS_LOG(DEBUG) << map->IRNode()->Name() << ": num_worker=" << map->IRNode()->num_workers();
 }
+
+TEST_F(MindDataTestOptimizationPass, MindDataTestTensorFusionPass) {
+  MS_LOG(INFO) << "Doing MindDataTestOptimizationPass-MindDataTestTensorFusionPass.";
+  std::string folder_path = datasets_root_path_ + "/testPK/data/";
+  std::shared_ptr<Dataset> root =
+    ImageFolder(folder_path, false)->Map({vision::Decode(), vision::RandomResizedCrop({100})}, {"image"});
+
+  TensorOpFusionPass fusion_pass;
+  bool modified = false;
+  std::shared_ptr<MapNode> map_node = std::dynamic_pointer_cast<MapNode>(root->IRNode());
+  // no deepcopy is performed because this doesn't go through tree_adapter
+  fusion_pass.Run(root->IRNode(), &modified);
+  EXPECT_EQ(modified, true);
+  ASSERT_NE(map_node, nullptr);
+  auto fused_ops = map_node->operations();
+  ASSERT_EQ(fused_ops.size(), 1);
+  ASSERT_EQ(fused_ops[0]->Name(), vision::kRandomCropDecodeResizeOperation);
+}
+
+TEST_F(MindDataTestOptimizationPass, MindDataTestTensorFusionPassPreBuiltTensorOperation) {
+  MS_LOG(INFO) << "Doing MindDataTestOptimizationPass-MindDataTestTensorFusionPassPreBuiltTensorOperation.";
+  std::string folder_path = datasets_root_path_ + "/testPK/data/";
+  // make prebuilt tensor operation
+  auto decode = std::make_shared<transforms::PreBuiltOperation>(vision::Decode()->Build());
+  auto resize = std::make_shared<transforms::PreBuiltOperation>(vision::RandomResizedCrop({100})->Build());
+  std::shared_ptr<Dataset> root = ImageFolder(folder_path, false)->Map({decode, resize}, {"image"});
+
+  TensorOpFusionPass fusion_pass;
+  bool modified = false;
+  std::shared_ptr<MapNode> map_node = std::dynamic_pointer_cast<MapNode>(root->IRNode());
+  // no deepcopy is performed because this doesn't go through tree_adapter
+  fusion_pass.Run(root->IRNode(), &modified);
+  EXPECT_EQ(modified, true);
+  ASSERT_NE(map_node, nullptr);
+  auto fused_ops = map_node->operations();
+  ASSERT_EQ(fused_ops.size(), 1);
+  ASSERT_EQ(fused_ops[0]->Name(), kRandomCropDecodeResizeOp);
+}
@@ -454,6 +454,38 @@ def test_callbacks_one_cb():
     assert events3 == expected_events3
 
 
+def test_clear_callback():
+    logger.info("test_clear_callback")
+
+    # this test case will test that callback is removed for get_dataset_size and output_shape/type
+    class FlagCallback(DSCallback):
+        def __init__(self):
+            super().__init__(step_size=1)
+            self.flag = False
+            self.row_cnt = 0
+
+        def ds_begin(self, ds_run_context):
+            # if callback isn't removed in getter pass, this function will be called
+            self.flag = True
+
+        def ds_step_begin(self, ds_run_context):
+            self.row_cnt += 1
+
+    data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
+    cb = FlagCallback()
+    # make sure variables are properly initialized before testing
+    assert not cb.flag and cb.row_cnt == 0
+    data = data.map(operations=(lambda x: x), callbacks=cb)
+    assert data.get_dataset_size() == 4
+    assert data.output_shapes() == [[]]
+    # make sure callback is never called by checking flag and row_cnt
+    assert not cb.flag and cb.row_cnt == 0
+    for _ in data.create_dict_iterator(num_epochs=1):
+        pass
+    # this ensure that callback is indeed called
+    assert cb.flag and cb.row_cnt == 4
+
+
 if __name__ == '__main__':
     test_callbacks_all_2cbs()
     test_callbacks_all_methods()
@@ -467,3 +499,4 @@ if __name__ == '__main__':
     test_callbacks_one_cb()
     test_callbacks_non_sink_mismatch_size()
     test_callbacks_train_end()
+    test_clear_callback()