strategy_ckpy_add_opt_param_r1.3

This commit is contained in:
yao_yf 2021-07-07 09:25:55 +08:00
parent 00149771ae
commit 0cf808e164
2 changed files with 14 additions and 3 deletions

View File

@ -2990,7 +2990,7 @@ bool IsGatherPInfo(const std::string &name) {
return false;
}
void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) {
void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
StrategyMap stra_map;
TensorInfoMap tensor_info_map;
ManualShapeMap manual_shape_map;
@ -3033,7 +3033,18 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) {
}
}
}
for (auto &cloned_parameter_node : root->parameters()) {
MS_EXCEPTION_IF_NULL(cloned_parameter_node);
auto cloned_parameter = cloned_parameter_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(cloned_parameter);
if (!ParameterIsCloned(cloned_parameter_node)) {
continue;
}
std::string cloned_param_name = cloned_parameter_node->cast<ParameterPtr>()->name();
auto cloned_param_layout = cloned_parameter_node->user_data<TensorLayout>();
tensor_info_map[cloned_param_name] = cloned_param_layout;
}
if (StrategyCheckpoint::GetInstance().Save(stra_map, tensor_info_map, &manual_shape_map) != SUCCESS) {
MS_LOG(EXCEPTION) << "Save strategy checkpoint failed";
}
@ -3783,7 +3794,7 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
// save strategy as checkpoint for multi-train
if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) {
CheckpointStrategy(all_nodes);
CheckpointStrategy(all_nodes, root);
}
// ForwardCommunication BackwardCommunication TensorRedistribution
ParallelCommunication(root, all_nodes, manager);

View File

@ -140,7 +140,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
ParameterMap NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth);
void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes);
void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root);
// main step of Parallel
bool StepParallel(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optimizer);