From aff9aaee354b12f261787e67a77b983b49ba376e Mon Sep 17 00:00:00 2001 From: Xiaoda Zhang Date: Mon, 14 Dec 2020 14:11:56 +0800 Subject: [PATCH] calculate used_devices in VirtualDataset --- .../frontend/parallel/ops_info/virtual_dataset_info.cc | 3 +++ mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc | 6 +----- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.cc index c8fa56f5d4b..8fb0cc283db 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.cc @@ -71,6 +71,9 @@ Status VirtualDatasetInfo::InferDevMatrixShape() { if (stage_device_size_ > batch_split_num) { dev_matrix_shape_.push_back(stage_device_size_ / batch_split_num); } + // Because 'VirtualDataSet' uses 'InitWithManualRepeatCalc' which does not calculates 'used_devices_', + // we calculate it here. + used_devices_ = batch_split_num; return SUCCESS; } diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc index 822daac1d0d..7baa7f0300c 100644 --- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc @@ -350,8 +350,6 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr & // Using CNode's UniqueIds to construct nodes Status ConstructCostGraphNodesByUniqueId(const std::vector &all_nodes, const FuncGraphPtr &root) { MS_LOG(INFO) << "Constructing nodes for cost graph begins."; - entire_costgraph = std::make_shared(); - entire_costgraph->SetDeviceMemoryAndCostParameter(); // The map from CNode's UniqueId to its operatorInfo std::map from_cnode_to_info; // The operator_infos in a loop @@ -370,7 +368,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector &all_node FindLastNodesUniqueId(all_nodes, &last_forward_node_ids); MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict"; } - // Step 1 + for (auto &node : all_nodes) { // NOTE: we only care about splittable Primitive operators auto cnode = node->cast(); @@ -454,8 +452,6 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector &all_node // Using CNode's UniqueIdThroughCopys to construct nodes Status ConstructCostGraphNodesByUniqueIdTC(const std::vector &all_nodes, const FuncGraphPtr &root) { MS_LOG(INFO) << "Constructing nodes for cost graph begins."; - entire_costgraph = std::make_shared(); - entire_costgraph->SetDeviceMemoryAndCostParameter(); // The map from CNode's UniqueIdThroughCopy to its operatorInfo std::map from_cnode_to_info; // The operator_infos in a loop