!9898 [Auto parallel] Fix the bug of claculating 'used_devices' in VirtualDataset

From: @xiaoda_zh
Reviewed-by: @stsuteng,@kisnwang
Signed-off-by: @stsuteng
This commit is contained in:
mindspore-ci-bot 2020-12-14 19:11:07 +08:00 committed by Gitee
commit e5af12f8ef
2 changed files with 4 additions and 5 deletions

View File

@ -71,6 +71,9 @@ Status VirtualDatasetInfo::InferDevMatrixShape() {
if (stage_device_size_ > batch_split_num) {
dev_matrix_shape_.push_back(stage_device_size_ / batch_split_num);
}
// Because 'VirtualDataSet' uses 'InitWithManualRepeatCalc' which does not calculates 'used_devices_',
// we calculate it here.
used_devices_ = batch_split_num;
return SUCCESS;
}

View File

@ -350,8 +350,6 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
// Using CNode's UniqueIds to construct nodes
Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
entire_costgraph = std::make_shared<CostGraph>();
entire_costgraph->SetDeviceMemoryAndCostParameter();
// The map from CNode's UniqueId to its operatorInfo
std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
// The operator_infos in a loop
@ -370,7 +368,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
FindLastNodesUniqueId(all_nodes, &last_forward_node_ids);
MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict";
}
// Step 1
for (auto &node : all_nodes) {
// NOTE: we only care about splittable Primitive operators
auto cnode = node->cast<CNodePtr>();
@ -454,8 +452,6 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
// Using CNode's UniqueIdThroughCopys to construct nodes
Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
entire_costgraph = std::make_shared<CostGraph>();
entire_costgraph->SetDeviceMemoryAndCostParameter();
// The map from CNode's UniqueIdThroughCopy to its operatorInfo
std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
// The operator_infos in a loop