hange strategys of last nodes in eval/predict at auto parallel mode

This commit is contained in:
yao_yf 2020-12-10 21:43:05 +08:00
parent 2ba79a32a2
commit 19fe28cb9b
4 changed files with 60 additions and 14 deletions

View File

@ -230,7 +230,8 @@ void InitCostGraph() {
entire_costgraph->SetDeviceMemoryAndCostParameter();
}
OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, StrategyMap *stra_map) {
OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, bool is_last_nodes,
StrategyMap *stra_map) {
MS_EXCEPTION_IF_NULL(prim);
MS_EXCEPTION_IF_NULL(cnode);
auto attrs = prim->attrs();
@ -287,7 +288,7 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
// If no strategy has been configured for this operator, then candidate strategies are generated for
// auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy.
// if strategy is set to load from checkpoint, it is prefer to load strategy from checkpoint .
if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt) {
if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt && !is_last_nodes) {
// Compute split_flag_list_, indicating which input has batch dimension. This is ONLY used for preparation for
// BatchParallelInfo operator
operator_info->ComputeBatchSplitFlagList();
@ -304,10 +305,16 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
} else {
// In this case, the configured strategy should be extracted to help setting cost
StrategyPtr strategyPtr;
if (load_strategy_from_ckpt) {
strategyPtr = (*stra_map)[strategy_key_name];
} else {
if (is_last_nodes) {
bool full_batch = ParallelContext::GetInstance()->full_batch();
strategyPtr = GenerateBatchParallelStrategy(operator_info, prim);
if (full_batch) {
SetLastNodeStrategy(strategyPtr);
}
} else if (StrategyFound(attrs)) {
strategyPtr = parallel::ExtractStrategy(attrs);
} else {
strategyPtr = (*stra_map)[strategy_key_name];
}
if (strategyPtr != nullptr) {
if (prim->name() == RESHAPE) {
@ -338,8 +345,10 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
}
// Using CNode's UniqueIds to construct nodes
Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
entire_costgraph = std::make_shared<CostGraph>();
entire_costgraph->SetDeviceMemoryAndCostParameter();
// The map from CNode's UniqueId to its operatorInfo
std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
// The operator_infos in a loop
@ -353,7 +362,12 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
}
}
std::vector<std::string> last_forward_node_ids;
if (!root->has_flag(TRAINING)) {
FindLastNodesUniqueId(all_nodes, &last_forward_node_ids);
MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict";
}
// Step 1
for (auto &node : all_nodes) {
// NOTE: we only care about splittable Primitive operators
auto cnode = node->cast<CNodePtr>();
@ -398,7 +412,9 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
(void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueId(), current_op_ptr));
continue;
}
auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map);
bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) !=
last_forward_node_ids.end();
auto operator_info = CreateTheOperatorInfo(prim, cnode, is_last_nodes, &stra_map);
if (operator_info == nullptr) {
return FAILED;
}
@ -433,8 +449,10 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
}
// Using CNode's UniqueIdThroughCopys to construct nodes
Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
entire_costgraph = std::make_shared<CostGraph>();
entire_costgraph->SetDeviceMemoryAndCostParameter();
// The map from CNode's UniqueIdThroughCopy to its operatorInfo
std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
// The operator_infos in a loop
@ -448,6 +466,11 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
}
}
std::vector<std::string> last_forward_node_ids;
if (!root->has_flag(TRAINING)) {
FindLastNodesUniqueId(all_nodes, &last_forward_node_ids);
MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict";
}
for (auto &node : all_nodes) {
// NOTE: we only care about splittable Primitive operators
auto cnode = node->cast<CNodePtr>();
@ -493,7 +516,9 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
continue;
}
// In this case, the corresponding OperatorInfo is not created, create the new one.
auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map);
bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) !=
last_forward_node_ids.end();
auto operator_info = CreateTheOperatorInfo(prim, cnode, is_last_nodes, &stra_map);
if (operator_info == nullptr) {
return FAILED;
}

View File

@ -1638,7 +1638,7 @@ bool FindPreNodes(const AnfNodePtr &node, vector<std::string> *unique_ids) {
return find;
}
void FindLastNodesUniqueId(const std::vector<AnfNodePtr> &all_nodes, vector<std::string> *unique_ids) {
void FindLastNodesUniqueId(const std::vector<AnfNodePtr> &all_nodes, std::vector<std::string> *unique_ids) {
MS_EXCEPTION_IF_NULL(unique_ids);
for (auto &node : all_nodes) {
auto cnode = node->cast<CNodePtr>();
@ -1754,10 +1754,10 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini
MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name()
<< " is empty, using batch parallel";
strategyPtr = GenerateBatchParallelStrategy(operator_, prim);
} else if (load_strategy_from_ckpt) {
strategyPtr = stra_map[strategy_key_name];
} else {
} else if (StrategyFound(attrs)) {
strategyPtr = ExtractStrategy(attrs);
} else {
strategyPtr = stra_map[strategy_key_name];
}
if (strategyPtr != nullptr) {
if (is_last_nodes && full_batch) {

View File

@ -165,6 +165,10 @@ bool IsUsedParameter(const FuncGraphPtr &graph, const AnfNodePtr &parameter);
void ApplyParallelOptOnParam(TensorLayout *tensor_layout, const OperatorInfoPtr &distribute_operator,
const CNodePtr &cnode, const AnfNodePtr &parameter, size_t index);
void SetLastNodeStrategy(const StrategyPtr strategyPtr);
void FindLastNodesUniqueId(const std::vector<AnfNodePtr> &all_nodes, std::vector<std::string> *unique_ids);
} // namespace parallel
} // namespace mindspore

View File

@ -67,3 +67,20 @@ def test_train_and_eval():
_executor.compile(eval_net, _x, _b, phase='eval', auto_parallel_mode=True)
context.reset_auto_parallel_context()
def test_train_and_eval_auto():
context.set_context(save_graphs=True, mode=0)
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16)
strategy1 = ((4, 4), (4, 4))
strategy2 = ((4, 4),)
net = Net(_w1, strategy1, strategy2)
eval_net = EvalNet(net, strategy2=strategy2)
net.set_auto_parallel()
net.set_train()
_executor.compile(net, _x, _b, phase='train', auto_parallel_mode=True)
eval_net.set_train(mode=False)
eval_net.set_auto_parallel()
_executor.compile(eval_net, _x, _b, phase='eval', auto_parallel_mode=True)
context.reset_auto_parallel_context()