forked from mindspore-Ecosystem/mindspore
Change strategies of last nodes in eval/predict in auto parallel mode
This commit is contained in:
parent
2ba79a32a2
commit
19fe28cb9b
|
@ -230,7 +230,8 @@ void InitCostGraph() {
|
|||
entire_costgraph->SetDeviceMemoryAndCostParameter();
|
||||
}
|
||||
|
||||
OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, StrategyMap *stra_map) {
|
||||
OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, bool is_last_nodes,
|
||||
StrategyMap *stra_map) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto attrs = prim->attrs();
|
||||
|
@ -287,7 +288,7 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
|
|||
// If no strategy has been configured for this operator, then candidate strategies are generated for
|
||||
// auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy.
|
||||
// If the strategy is set to be loaded from a checkpoint, loading it from the checkpoint is preferred.
|
||||
if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt) {
|
||||
if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt && !is_last_nodes) {
|
||||
// Compute split_flag_list_, indicating which input has batch dimension. This is ONLY used for preparation for
|
||||
// BatchParallelInfo operator
|
||||
operator_info->ComputeBatchSplitFlagList();
|
||||
|
@ -304,10 +305,16 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
|
|||
} else {
|
||||
// In this case, the configured strategy should be extracted to help setting cost
|
||||
StrategyPtr strategyPtr;
|
||||
if (load_strategy_from_ckpt) {
|
||||
strategyPtr = (*stra_map)[strategy_key_name];
|
||||
} else {
|
||||
if (is_last_nodes) {
|
||||
bool full_batch = ParallelContext::GetInstance()->full_batch();
|
||||
strategyPtr = GenerateBatchParallelStrategy(operator_info, prim);
|
||||
if (full_batch) {
|
||||
SetLastNodeStrategy(strategyPtr);
|
||||
}
|
||||
} else if (StrategyFound(attrs)) {
|
||||
strategyPtr = parallel::ExtractStrategy(attrs);
|
||||
} else {
|
||||
strategyPtr = (*stra_map)[strategy_key_name];
|
||||
}
|
||||
if (strategyPtr != nullptr) {
|
||||
if (prim->name() == RESHAPE) {
|
||||
|
@ -338,8 +345,10 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
|
|||
}
|
||||
|
||||
// Using CNode's UniqueIds to construct nodes
|
||||
Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
|
||||
Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
|
||||
MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
|
||||
entire_costgraph = std::make_shared<CostGraph>();
|
||||
entire_costgraph->SetDeviceMemoryAndCostParameter();
|
||||
// The map from CNode's UniqueId to its operatorInfo
|
||||
std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
|
||||
// The operator_infos in a loop
|
||||
|
@ -353,7 +362,12 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
|
|||
MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::string> last_forward_node_ids;
|
||||
if (!root->has_flag(TRAINING)) {
|
||||
FindLastNodesUniqueId(all_nodes, &last_forward_node_ids);
|
||||
MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict";
|
||||
}
|
||||
// Step 1
|
||||
for (auto &node : all_nodes) {
|
||||
// NOTE: we only care about splittable Primitive operators
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
|
@ -398,7 +412,9 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
|
|||
(void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueId(), current_op_ptr));
|
||||
continue;
|
||||
}
|
||||
auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map);
|
||||
bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) !=
|
||||
last_forward_node_ids.end();
|
||||
auto operator_info = CreateTheOperatorInfo(prim, cnode, is_last_nodes, &stra_map);
|
||||
if (operator_info == nullptr) {
|
||||
return FAILED;
|
||||
}
|
||||
|
@ -433,8 +449,10 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
|
|||
}
|
||||
|
||||
// Using CNode's UniqueIdThroughCopys to construct nodes
|
||||
Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
|
||||
Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
|
||||
MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
|
||||
entire_costgraph = std::make_shared<CostGraph>();
|
||||
entire_costgraph->SetDeviceMemoryAndCostParameter();
|
||||
// The map from CNode's UniqueIdThroughCopy to its operatorInfo
|
||||
std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
|
||||
// The operator_infos in a loop
|
||||
|
@ -448,6 +466,11 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
|
|||
MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
|
||||
}
|
||||
}
|
||||
std::vector<std::string> last_forward_node_ids;
|
||||
if (!root->has_flag(TRAINING)) {
|
||||
FindLastNodesUniqueId(all_nodes, &last_forward_node_ids);
|
||||
MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict";
|
||||
}
|
||||
for (auto &node : all_nodes) {
|
||||
// NOTE: we only care about splittable Primitive operators
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
|
@ -493,7 +516,9 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
|
|||
continue;
|
||||
}
|
||||
// In this case, the corresponding OperatorInfo is not created, create the new one.
|
||||
auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map);
|
||||
bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) !=
|
||||
last_forward_node_ids.end();
|
||||
auto operator_info = CreateTheOperatorInfo(prim, cnode, is_last_nodes, &stra_map);
|
||||
if (operator_info == nullptr) {
|
||||
return FAILED;
|
||||
}
|
||||
|
|
|
@ -1638,7 +1638,7 @@ bool FindPreNodes(const AnfNodePtr &node, vector<std::string> *unique_ids) {
|
|||
return find;
|
||||
}
|
||||
|
||||
void FindLastNodesUniqueId(const std::vector<AnfNodePtr> &all_nodes, vector<std::string> *unique_ids) {
|
||||
void FindLastNodesUniqueId(const std::vector<AnfNodePtr> &all_nodes, std::vector<std::string> *unique_ids) {
|
||||
MS_EXCEPTION_IF_NULL(unique_ids);
|
||||
for (auto &node : all_nodes) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
|
@ -1754,10 +1754,10 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini
|
|||
MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name()
|
||||
<< " is empty, using batch parallel";
|
||||
strategyPtr = GenerateBatchParallelStrategy(operator_, prim);
|
||||
} else if (load_strategy_from_ckpt) {
|
||||
strategyPtr = stra_map[strategy_key_name];
|
||||
} else {
|
||||
} else if (StrategyFound(attrs)) {
|
||||
strategyPtr = ExtractStrategy(attrs);
|
||||
} else {
|
||||
strategyPtr = stra_map[strategy_key_name];
|
||||
}
|
||||
if (strategyPtr != nullptr) {
|
||||
if (is_last_nodes && full_batch) {
|
||||
|
|
|
@ -165,6 +165,10 @@ bool IsUsedParameter(const FuncGraphPtr &graph, const AnfNodePtr &parameter);
|
|||
|
||||
void ApplyParallelOptOnParam(TensorLayout *tensor_layout, const OperatorInfoPtr &distribute_operator,
|
||||
const CNodePtr &cnode, const AnfNodePtr &parameter, size_t index);
|
||||
|
||||
void SetLastNodeStrategy(const StrategyPtr strategyPtr);
|
||||
|
||||
void FindLastNodesUniqueId(const std::vector<AnfNodePtr> &all_nodes, std::vector<std::string> *unique_ids);
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -67,3 +67,20 @@ def test_train_and_eval():
|
|||
_executor.compile(eval_net, _x, _b, phase='eval', auto_parallel_mode=True)
|
||||
|
||||
context.reset_auto_parallel_context()
|
||||
|
||||
def test_train_and_eval_auto():
|
||||
context.set_context(save_graphs=True, mode=0)
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16)
|
||||
strategy1 = ((4, 4), (4, 4))
|
||||
strategy2 = ((4, 4),)
|
||||
net = Net(_w1, strategy1, strategy2)
|
||||
eval_net = EvalNet(net, strategy2=strategy2)
|
||||
net.set_auto_parallel()
|
||||
net.set_train()
|
||||
_executor.compile(net, _x, _b, phase='train', auto_parallel_mode=True)
|
||||
|
||||
eval_net.set_train(mode=False)
|
||||
eval_net.set_auto_parallel()
|
||||
_executor.compile(eval_net, _x, _b, phase='eval', auto_parallel_mode=True)
|
||||
|
||||
context.reset_auto_parallel_context()
|
||||
|
|
Loading…
Reference in New Issue