forked from mindspore-Ecosystem/mindspore
!2885 skip strategy ckpt save for reshape
Merge pull request !2885 from yao_yf/skip_reshape_strategy_ckpt
This commit is contained in:
commit
7f37bfbb16
|
@ -41,7 +41,6 @@ bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES;
|
|||
bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
|
||||
bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS;
|
||||
int32_t RUN_PHASE = DEFAULT_RUN_PHASE;
|
||||
constexpr char RESHAPEINFO[] = "ReshapeInfo";
|
||||
|
||||
void CostGraph::SetDeviceMemoryAndCostParameter() {
|
||||
MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance());
|
||||
|
|
|
@ -65,6 +65,7 @@ constexpr char STEP_PARALLEL_END[] = "step_parallel_end";
|
|||
constexpr char STEP_AUTO_PARALLEL_BEGIN[] = "step_auto_parallel_begin.dot";
|
||||
constexpr char REQUIRES_GRAD[] = "requires_grad";
|
||||
constexpr char PARAM_NAME[] = "name";
|
||||
constexpr char RESHAPEINFO[] = "ReshapeInfo";
|
||||
|
||||
constexpr char RELU_TYPE[] = "relu";
|
||||
constexpr char RELU6_TYPE[] = "relu6";
|
||||
|
|
|
@ -2120,6 +2120,9 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) {
|
|||
MS_EXCEPTION_IF_NULL(prim);
|
||||
OperatorInfoPtr operator_info = cnode->operator_info();
|
||||
if (operator_info) {
|
||||
if (operator_info->name().find(RESHAPEINFO) != std::string::npos) {
|
||||
continue;
|
||||
}
|
||||
StrategyPtr strategyPtr = operator_info->strategy();
|
||||
MS_EXCEPTION_IF_NULL(node->scope());
|
||||
stra_map[param_name] = strategyPtr;
|
||||
|
|
|
@ -93,6 +93,7 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) {
|
|||
parallel_strategy_item->set_node_name(node_stra.first);
|
||||
straspb::ParallelStrategys *parallel_strategys = parallel_strategy_item->mutable_parallel_strategys();
|
||||
MS_EXCEPTION_IF_NULL(parallel_strategys);
|
||||
MS_EXCEPTION_IF_NULL(node_stra.second);
|
||||
parallel_strategys->set_stage(IntToUint(node_stra.second->GetInputStage()));
|
||||
for (auto &dims : node_stra.second->GetInputDim()) {
|
||||
straspb::ParallelStrategy *parallel_strategy = parallel_strategys->add_parallel_strategy();
|
||||
|
|
Loading…
Reference in New Issue