forked from mindspore-Ecosystem/mindspore
strategy_ckpy_add_opt_param_r1.3
This commit is contained in:
parent
00149771ae
commit
0cf808e164
|
@ -2990,7 +2990,7 @@ bool IsGatherPInfo(const std::string &name) {
|
|||
return false;
|
||||
}
|
||||
|
||||
void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) {
|
||||
void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
|
||||
StrategyMap stra_map;
|
||||
TensorInfoMap tensor_info_map;
|
||||
ManualShapeMap manual_shape_map;
|
||||
|
@ -3033,7 +3033,18 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) {
|
|||
}
|
||||
}
|
||||
}
|
||||
for (auto &cloned_parameter_node : root->parameters()) {
|
||||
MS_EXCEPTION_IF_NULL(cloned_parameter_node);
|
||||
auto cloned_parameter = cloned_parameter_node->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(cloned_parameter);
|
||||
|
||||
if (!ParameterIsCloned(cloned_parameter_node)) {
|
||||
continue;
|
||||
}
|
||||
std::string cloned_param_name = cloned_parameter_node->cast<ParameterPtr>()->name();
|
||||
auto cloned_param_layout = cloned_parameter_node->user_data<TensorLayout>();
|
||||
tensor_info_map[cloned_param_name] = cloned_param_layout;
|
||||
}
|
||||
if (StrategyCheckpoint::GetInstance().Save(stra_map, tensor_info_map, &manual_shape_map) != SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << "Save strategy checkpoint failed";
|
||||
}
|
||||
|
@ -3783,7 +3794,7 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
|
|||
|
||||
// save strategy as checkpoint for multi-train
|
||||
if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) {
|
||||
CheckpointStrategy(all_nodes);
|
||||
CheckpointStrategy(all_nodes, root);
|
||||
}
|
||||
// ForwardCommunication BackwardCommunication TensorRedistribution
|
||||
ParallelCommunication(root, all_nodes, manager);
|
||||
|
|
|
@ -140,7 +140,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
|
|||
|
||||
ParameterMap NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth);
|
||||
|
||||
void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes);
|
||||
void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root);
|
||||
|
||||
// main step of Parallel
|
||||
bool StepParallel(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optimizer);
|
||||
|
|
Loading…
Reference in New Issue