forked from mindspore-Ecosystem/mindspore
!31523 [AutoParallel]fix_pipeline_parallell_opt_detection_bug
Merge pull request !31523 from lichen/fix_opt_detection_bug
This commit is contained in:
commit
fd40988681
|
@ -1015,28 +1015,32 @@ void PipelineTransformer::RedundancyNode(const AnfNodePtr &node,
|
|||
|
||||
bool PipelineTransformer::IsRedundancyParameter(const AnfNodePtr ¶meter) {
|
||||
// RedundancyParameter: other stage's parameters included corresponding cloned parameters.
|
||||
auto parameters = root_->parameters();
|
||||
auto param_ptr = parameter->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(param_ptr);
|
||||
if (!param_ptr->has_default()) {
|
||||
return false;
|
||||
}
|
||||
auto param_name = param_ptr->name();
|
||||
for (auto ¶m : parameters) {
|
||||
if (ParameterIsCloned(param)) {
|
||||
continue;
|
||||
std::set<int64_t> stage_set;
|
||||
if (!ParameterIsCloned(parameter)) {
|
||||
stage_set = parameter_color_map.at(parameter);
|
||||
} else {
|
||||
auto parameters = root_->parameters();
|
||||
auto param_name = param_ptr->name();
|
||||
for (auto ¶m : parameters) {
|
||||
if (ParameterIsCloned(param)) {
|
||||
continue;
|
||||
}
|
||||
auto non_cloned_param = param->cast<ParameterPtr>();
|
||||
if (param_name.find(non_cloned_param->name()) == std::string::npos) {
|
||||
continue;
|
||||
}
|
||||
stage_set = parameter_color_map.at(param);
|
||||
}
|
||||
auto non_cloned_param = param->cast<ParameterPtr>();
|
||||
if (param_name.find(non_cloned_param->name()) == std::string::npos) {
|
||||
continue;
|
||||
}
|
||||
auto stage_set = parameter_color_map.at(param);
|
||||
if (stage_set.empty()) {
|
||||
return false;
|
||||
}
|
||||
return !stage_set.count(stage_);
|
||||
}
|
||||
return false;
|
||||
if (stage_set.empty()) {
|
||||
return false;
|
||||
}
|
||||
return !stage_set.count(stage_);
|
||||
}
|
||||
|
||||
void PipelineTransformer::ElimParameter() {
|
||||
|
@ -1046,6 +1050,7 @@ void PipelineTransformer::ElimParameter() {
|
|||
if (!IsRedundancyParameter(parameter)) {
|
||||
continue;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Parameter:" << parameter->DebugString() << " is Redundancy.";
|
||||
RedundancyNode(parameter, &make_tuple_map);
|
||||
}
|
||||
for (auto &temp : make_tuple_map) {
|
||||
|
|
Loading…
Reference in New Issue