!41861 Do not split graph in some cases.

Merge pull request !41861 from ZPaC/decouple2
i-robot 2022-09-15 01:16:07 +00:00 committed by Gitee
commit a323ea00f4
4 changed files with 26 additions and 4 deletions


@@ -341,6 +341,7 @@ distributed::DistExecutionMode GenerateStrategy() {
enable_embedding_cache = ps::PSContext::instance()->cache_enable();
#endif
std::string parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
MS_LOG(INFO) << "Current parallel mode is " << parallel_mode;
bool using_parallel = (parallel_mode != parallel::kStandalone);
// The conditions' priority is: EmbeddingCache > Parameter Server > General.
if (enable_embedding_cache) {
@@ -352,6 +353,7 @@ distributed::DistExecutionMode GenerateStrategy() {
} else {
strategy = distributed::DistExecutionMode::kGeneralMode;
}
MS_LOG(INFO) << "Generated distributed strategy is " << strategy;
return strategy;
}
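The two hunks above only show the first and last branches of the priority chain; the Parameter Server and parallel branches in between are elided by the diff. Below is a minimal, self-contained sketch of the intended ordering, using a placeholder enum and plain booleans; the enable_ps flag name and the exact slot of the new parallel branch are assumptions, not taken from the diff.

#include <iostream>

// Placeholder enum mirroring the modes referenced in the diff; purely illustrative.
enum class DistExecutionModeSketch { kEmbeddingCacheMode, kPSMode, kParallelMode, kGeneralMode };

// Sketch of the priority described in GenerateStrategy():
// EmbeddingCache > Parameter Server > Parallel > General (the parallel slot is an assumption).
DistExecutionModeSketch GenerateStrategySketch(bool enable_embedding_cache, bool enable_ps, bool using_parallel) {
  if (enable_embedding_cache) {
    return DistExecutionModeSketch::kEmbeddingCacheMode;
  } else if (enable_ps) {
    return DistExecutionModeSketch::kPSMode;
  } else if (using_parallel) {
    return DistExecutionModeSketch::kParallelMode;
  }
  return DistExecutionModeSketch::kGeneralMode;
}

int main() {
  // A standalone run with no cache/PS/parallel features falls back to general mode.
  std::cout << static_cast<int>(GenerateStrategySketch(false, false, false)) << std::endl;  // prints 3
  return 0;
}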
@@ -1074,18 +1076,24 @@ void GraphSplitter::Run() {
DyeGraph();
// If all nodes are on this process, there is no need to split the graph, so return.
if (!NeedSplitGraph()) {
MS_LOG(INFO) << "No need to build and split distributed graph.";
MS_LOG(INFO) << "All nodes are on this precoess so there's no need to build and split distributed graph.";
return;
}
// Step 2: Create exec_mode_ according to the current execution mode.
CreateExecutionMode();
// If this is general mode but no label is set, do not split the graph, to avoid nodes being unexpectedly optimized out.
if (mode_ == distributed::DistExecutionMode::kGeneralMode && !GraphHasLabel(func_graph_)) {
MS_LOG(INFO) << "This graph has no label on it in general mode. So no need to split.";
return;
}
// Step 3: Prebuild the distributed graph before it gets split.
exec_mode_->PreBuildDistributedGraph();
if (!NeedSplitGraph()) {
MS_LOG(INFO) << "No need to build and split distributed graph.";
MS_LOG(INFO) << "All nodes are on this precoess so there's no need to build and split distributed graph.";
return;
}
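Taken together, Run() now has early-return guards before any splitting work happens. The sketch below condenses that control flow using simplified placeholder types (not the real GraphSplitter API), and it collapses the second NeedSplitGraph check after PreBuildDistributedGraph into the final step.

#include <iostream>

// Simplified placeholders; the real GraphSplitter works on a FuncGraph and
// per-node labels rather than these booleans.
struct GraphInfoSketch {
  bool needs_split = false;      // any node is labeled for another process
  bool is_general_mode = false;  // strategy fell back to kGeneralMode
  bool has_label = false;        // at least one node carries a distributed label
};

// Mirrors the control flow of GraphSplitter::Run() after this change:
// return early whenever splitting would be a no-op.
void RunSketch(const GraphInfoSketch &g) {
  if (!g.needs_split) {
    std::cout << "All nodes are on this process; nothing to split." << std::endl;
    return;
  }
  if (g.is_general_mode && !g.has_label) {
    std::cout << "General mode without labels; skip splitting." << std::endl;
    return;
  }
  std::cout << "Pre-build and split the distributed graph." << std::endl;
}

int main() {
  GraphInfoSketch g;
  g.needs_split = true;
  g.is_general_mode = true;  // no labels set, so the sketch returns before splitting
  RunSketch(g);
  return 0;
}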
@@ -1150,6 +1158,8 @@ void GraphSplitter::CreateExecutionMode() {
exec_mode_ = std::make_unique<ParameterServerMode>(func_graph_, &node_labels_, rank_id_, role_);
} else if (mode_ == distributed::DistExecutionMode::kEmbeddingCacheMode) {
exec_mode_ = std::make_unique<EmbeddingCacheMode>(func_graph_, &node_labels_, rank_id_, role_);
} else if (mode_ == distributed::DistExecutionMode::kParallelMode) {
exec_mode_ = std::make_unique<ParallelMode>(func_graph_, &node_labels_, rank_id_, role_);
} else if (mode_ == distributed::DistExecutionMode::kGeneralMode) {
exec_mode_ = std::make_unique<GeneralMode>(func_graph_, &node_labels_, rank_id_, role_);
}


@@ -380,6 +380,15 @@ class GeneralMode : public DistributedExecutionMode {
~GeneralMode() = default;
};
// The mode applied when the AutoParallel feature is enabled.
class ParallelMode : public DistributedExecutionMode {
public:
explicit ParallelMode(const FuncGraphPtr &func_graph, NodeLabels *node_labels, uint32_t rank_id,
const std::string &role)
: DistributedExecutionMode(func_graph, node_labels, rank_id, role) {}
~ParallelMode() = default;
};
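Like GeneralMode, the new ParallelMode is a thin subclass that only forwards construction and relies on the base-class hooks until mode-specific processing is added. The sketch below shows that pattern with a simplified, hypothetical stand-in for DistributedExecutionMode; the role string "MS_WORKER" is just an example value.

#include <iostream>
#include <memory>
#include <string>

// Simplified, hypothetical stand-in for DistributedExecutionMode.
class ExecutionModeBase {
 public:
  explicit ExecutionModeBase(std::string role) : role_(std::move(role)) {}
  virtual ~ExecutionModeBase() = default;
  // Hook called before the graph is split; the default does nothing special,
  // mirroring how a thin subclass can rely on base behavior.
  virtual void PreBuildDistributedGraph() { std::cout << "base pre-build for " << role_ << std::endl; }

 protected:
  std::string role_;
};

// A thin subclass in the style of ParallelMode/GeneralMode: it only inherits
// the constructor and the hooks until mode-specific logic is needed.
class ParallelModeSketch : public ExecutionModeBase {
 public:
  using ExecutionModeBase::ExecutionModeBase;
};

int main() {
  std::unique_ptr<ExecutionModeBase> mode = std::make_unique<ParallelModeSketch>("MS_WORKER");
  mode->PreBuildDistributedGraph();
  return 0;
}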
// This class is used as an action in the pipeline. It processes the graph and splits the nodes among the processes
// in the cluster.
class GraphSplitter {


@@ -58,7 +58,8 @@ void CollectiveInitializer::InitCollective() {
#endif
} else {
if (!distributed::Initialize()) {
MS_LOG(EXCEPTION) << "Failed to initialize distributed execution for NCCL.";
MS_LOG(EXCEPTION) << "Failed to initialize distributed execution for NCCL. Maybe the MindSpore cluster is not "
"successfully built. Please check schuduler and other nodes' log.";
}
}
CollectiveInitializer::instance().collective_inited_ = true;


@@ -37,6 +37,7 @@ from mindspore.common.parameter import Parameter
from mindspore.train import Model
from mindspore.common import set_seed
from mindspore.communication.management import init
from mindspore.parallel._ps_context import _is_role_worker
parser = argparse.ArgumentParser(description='test_ps_lenet')
parser.add_argument("--device_target", type=str, default="GPU")
@@ -177,7 +178,8 @@ class NetFactory:
no_ps = self.no_ps_impl(ds2)
print(part_ps)
print(no_ps)
assert np.allclose(no_ps, part_ps, rtol=1.0e-4, atol=1.0e-4)
if _is_role_worker():
assert np.allclose(no_ps, part_ps, rtol=1.0e-4, atol=1.0e-4)
if __name__ == "__main__":