!2885 skip strategy ckpt save for reshape

Merge pull request !2885 from yao_yf/skip_reshape_strategy_ckpt
This commit is contained in:
mindspore-ci-bot 2020-07-07 11:17:22 +08:00 committed by Gitee
commit 7f37bfbb16
4 changed files with 5 additions and 1 deletions

View File

@ -41,7 +41,6 @@ bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES;
bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS;
int32_t RUN_PHASE = DEFAULT_RUN_PHASE;
constexpr char RESHAPEINFO[] = "ReshapeInfo";
void CostGraph::SetDeviceMemoryAndCostParameter() {
MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance());

View File

@ -65,6 +65,7 @@ constexpr char STEP_PARALLEL_END[] = "step_parallel_end";
constexpr char STEP_AUTO_PARALLEL_BEGIN[] = "step_auto_parallel_begin.dot";
constexpr char REQUIRES_GRAD[] = "requires_grad";
constexpr char PARAM_NAME[] = "name";
constexpr char RESHAPEINFO[] = "ReshapeInfo";
constexpr char RELU_TYPE[] = "relu";
constexpr char RELU6_TYPE[] = "relu6";

View File

@ -2120,6 +2120,9 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(prim);
OperatorInfoPtr operator_info = cnode->operator_info();
if (operator_info) {
if (operator_info->name().find(RESHAPEINFO) != std::string::npos) {
continue;
}
StrategyPtr strategyPtr = operator_info->strategy();
MS_EXCEPTION_IF_NULL(node->scope());
stra_map[param_name] = strategyPtr;

View File

@ -93,6 +93,7 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) {
parallel_strategy_item->set_node_name(node_stra.first);
straspb::ParallelStrategys *parallel_strategys = parallel_strategy_item->mutable_parallel_strategys();
MS_EXCEPTION_IF_NULL(parallel_strategys);
MS_EXCEPTION_IF_NULL(node_stra.second);
parallel_strategys->set_stage(IntToUint(node_stra.second->GetInputStage()));
for (auto &dims : node_stra.second->GetInputDim()) {
straspb::ParallelStrategy *parallel_strategy = parallel_strategys->add_parallel_strategy();