!31523 [AutoParallel]fix_pipeline_parallell_opt_detection_bug

Merge pull request !31523 from lichen/fix_opt_detection_bug
This commit is contained in:
i-robot 2022-03-19 08:05:08 +00:00 committed by Gitee
commit fd40988681
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 20 additions and 15 deletions

View File

@ -1015,28 +1015,32 @@ void PipelineTransformer::RedundancyNode(const AnfNodePtr &node,
bool PipelineTransformer::IsRedundancyParameter(const AnfNodePtr &parameter) {
// RedundancyParameter: other stage's parameters included corresponding cloned parameters.
auto parameters = root_->parameters();
auto param_ptr = parameter->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param_ptr);
if (!param_ptr->has_default()) {
return false;
}
auto param_name = param_ptr->name();
for (auto &param : parameters) {
if (ParameterIsCloned(param)) {
continue;
std::set<int64_t> stage_set;
if (!ParameterIsCloned(parameter)) {
stage_set = parameter_color_map.at(parameter);
} else {
auto parameters = root_->parameters();
auto param_name = param_ptr->name();
for (auto &param : parameters) {
if (ParameterIsCloned(param)) {
continue;
}
auto non_cloned_param = param->cast<ParameterPtr>();
if (param_name.find(non_cloned_param->name()) == std::string::npos) {
continue;
}
stage_set = parameter_color_map.at(param);
}
auto non_cloned_param = param->cast<ParameterPtr>();
if (param_name.find(non_cloned_param->name()) == std::string::npos) {
continue;
}
auto stage_set = parameter_color_map.at(param);
if (stage_set.empty()) {
return false;
}
return !stage_set.count(stage_);
}
return false;
if (stage_set.empty()) {
return false;
}
return !stage_set.count(stage_);
}
void PipelineTransformer::ElimParameter() {
@ -1046,6 +1050,7 @@ void PipelineTransformer::ElimParameter() {
if (!IsRedundancyParameter(parameter)) {
continue;
}
MS_LOG(DEBUG) << "Parameter:" << parameter->DebugString() << " is Redundancy.";
RedundancyNode(parameter, &make_tuple_map);
}
for (auto &temp : make_tuple_map) {