!1187 Checkpoint and restore parameter's shape
Merge pull request !1187 from yangzhenzhang/ckpt-and-restore-parameter-shape
commit 3b6de89368
@@ -22,12 +22,15 @@
#include <memory>
#include <numeric>
#include <utility>
#include <map>

#include "common/utils.h"
#include "parallel/device_manager.h"

namespace mindspore {
namespace parallel {
static std::map<std::string, std::vector<int>> param_shapes;

std::vector<std::string> PARALLEL_MODE_LIST = {STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL,
                                               AUTO_PARALLEL};
std::vector<std::string> STRATEGY_SEARCH_MODE_LIST = {DYNAMIC_PROGRAMMING, RECURSIVE_PROGRAMMING};
@@ -136,5 +139,56 @@ const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const
  }
  return {};
}

// Clear param_shapes before training in auto-parallel or semi-auto-parallel mode
void ParallelParameterContextInit(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  if (!func_graph->has_flag(AUTO_PARALLEL) || !func_graph->has_flag(TRAINING)) {
    return;
  }
  param_shapes.clear();
}

// Restore the parameters' shape for evaluation/prediction in auto-parallel or semi-auto-parallel mode
void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
                                                 AbstractBasePtr ptr) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(param_node);
  MS_EXCEPTION_IF_NULL(ptr);
  if (!func_graph->has_flag(AUTO_PARALLEL) || (func_graph->flags().count(TRAINING) == 0) ||
      func_graph->flags()[TRAINING]) {
    return;
  }

  auto iter = param_shapes.find(param_node->name());
  if (iter == param_shapes.end()) {
    MS_LOG(WARNING) << "Cannot find the shape for parameter " << param_node->name();
    return;
  }
  std::vector<int> shape = iter->second;
  std::shared_ptr<abstract::BaseShape> base_shape = std::make_shared<abstract::Shape>(shape);
  ptr->set_shape(base_shape);
  MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
}

// Checkpoint the parameters' shape for training in auto-parallel or semi-auto-parallel mode
void ParallelParameterContextCkptInTraining(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
                                            const AbstractBasePtr &ptr) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(param_node);
  MS_EXCEPTION_IF_NULL(ptr);
  if (!func_graph->has_flag(AUTO_PARALLEL) || !func_graph->has_flag(TRAINING)) {
    return;
  }

  std::vector<int> shape = dyn_cast<abstract::Shape>(ptr->GetShapeTrack())->shape();
  auto ret = param_shapes.try_emplace(param_node->name(), shape);
  if (!ret.second) {
    MS_LOG(EXCEPTION) << "The shape for parameter " << param_node->name() << " already exists";
    return;
  }

  MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
}
}  // namespace parallel
}  // namespace mindspore
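For orientation, the three functions above amount to a small map keyed by parameter name. A minimal Python sketch of the same bookkeeping, with hypothetical names (this is not MindSpore API):

param_shapes = {}

def init_in_training():
    # Mirrors ParallelParameterContextInit: training starts with a clean map.
    param_shapes.clear()

def ckpt_in_training(name, shape):
    # Mirrors ParallelParameterContextCkptInTraining: a duplicate name is an error.
    if name in param_shapes:
        raise RuntimeError("The shape for parameter " + name + " already exists")
    param_shapes[name] = list(shape)

def restore_in_no_training(name):
    # Mirrors ParallelParameterContextRestoreInNoTraining: warn and skip if absent.
    if name not in param_shapes:
        print("WARNING: cannot find the shape for parameter " + name)
        return None
    return param_shapes[name]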
@@ -26,6 +26,9 @@
#include "parallel/ops_info/ops_utils.h"
#include "parallel/status.h"
#include "utils/convert_utils.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "debug/info.h"

namespace mindspore {
namespace parallel {
@@ -38,6 +41,8 @@ constexpr char SEMI_AUTO_PARALLEL[] = "semi_auto_parallel";
constexpr char DYNAMIC_PROGRAMMING[] = "dynamic_programming";
constexpr char RECURSIVE_PROGRAMMING[] = "recursive_programming";

constexpr char TRAINING[] = "training";

class ParallelContext {
 public:
  ~ParallelContext() = default;
@@ -114,6 +119,12 @@ class ParallelContext {
  std::string strategy_ckpt_load_file_;
  std::string strategy_ckpt_save_file_;
};

void ParallelParameterContextInit(const FuncGraphPtr &func_graph);
void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
                                                 AbstractBasePtr ptr);
void ParallelParameterContextCkptInTraining(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
                                            const AbstractBasePtr &ptr);
}  // namespace parallel
}  // namespace mindspore
@@ -25,6 +25,7 @@

#include "ir/func_graph_cloner.h"
#include "parallel/costmodel_context.h"
#include "parallel/context.h"
#include "pipeline/pass.h"
#include "pipeline/parse/parse_base.h"
#include "pipeline/parse/data_converter.h"
@@ -217,6 +218,8 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
  FuncGraphPtr func_graph = res->func_graph();
  abstract::AbstractBasePtrList args_spec = res->args_spec();

  parallel::ParallelParameterContextInit(func_graph);

  // assume there is no KeywordArgument for the top graph
  // get the hyper parameter
  for (const auto &param : func_graph->parameters()) {
@@ -224,7 +227,10 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
    if (param_node->has_default()) {
      AbstractBasePtr ptr =
        abstract::FromValue(parse::data_converter::PyDataToValue(param_node->default_param()), true);

      parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr);
      args_spec.push_back(ptr);
      parallel::ParallelParameterContextCkptInTraining(func_graph, param_node, ptr);
    }
  }
  // Analyze
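Note the ordering in the loop above: ParallelParameterContextRestoreInNoTraining widens ptr back to the checkpointed full shape before it is pushed into args_spec, so non-training shape inference sees the original parameter shape, while ParallelParameterContextCkptInTraining runs after the push and records shapes only when the graph carries the TRAINING flag.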
@@ -379,7 +379,7 @@ class _Executor:

         self._params_init_data(obj, params)
         if not enable_debug_runtime or enable_ge:
-            if auto_parallel_mode:
+            if auto_parallel_mode and "train" in phase:
                 obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
                 obj.load_parameter_slice(params)
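With the added phase check, get_parameter_layout and load_parameter_slice run only when a training phase is compiled; an eval-phase compile skips the slicing step and instead relies on the shapes checkpointed during the preceding train compile.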
@@ -47,7 +47,7 @@ def test_get_parameter_layout():
     net = Net(strategy1, strategy2, weight)
     net.set_auto_parallel()
     exe = me._executor
-    exe.compile(net, x, auto_parallel_mode=True)
+    exe.compile(net, x, phase='train', auto_parallel_mode=True)
     x_layout = [[2, 4], [1, -1], [16, 32]]  # device_arrangement = [2, 4], tensor_map = [1, -1]
     weight_layout = [[2, 4], [0, -1], [16, 32]]  # device_arrangement = [2, 4], tensor_map = [0, -1]
     expect_dict = {'x': x_layout, 'w1': weight_layout}
@@ -0,0 +1,68 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.nn import Cell, TrainOneStepCell, Momentum
from mindspore.ops import operations as P
from mindspore.common.api import _executor


class Net(Cell):
    def __init__(self, mul_weight, strategy1=None, strategy2=None):
        super().__init__()
        self.mul = P.Mul().set_strategy(strategy1)
        self.neg = P.Neg().set_strategy(strategy2)
        self.mul_weight = Parameter(mul_weight, "w1")

    def construct(self, x, b):
        out = self.mul(x, self.mul_weight)
        out = self.neg(out)
        return out


class EvalNet(Cell):
    def __init__(self, network, strategy2=None):
        super().__init__()
        self.network = network
        self.relu = P.ReLU().set_strategy(strategy2)

    def construct(self, x, b):
        out = self.network(x, b)
        out = self.relu(out)
        return out


_x = Tensor(np.ones([8, 8]), dtype=ms.float32)
_w1 = Tensor(np.ones([8, 8]), dtype=ms.float32)
_b = Tensor(np.ones([8, 8]), dtype=ms.float32)


def test_train_and_eval():
    context.set_context(save_graphs=True, mode=0)  # mode=0 is graph mode
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16)
    strategy1 = ((4, 4), (4, 4))
    strategy2 = ((4, 4),)
    net = Net(_w1, strategy1, strategy2)
    eval_net = EvalNet(net, strategy2=strategy2)
    net.set_train()
    net.set_auto_parallel()
    _executor.compile(net, _x, _b, phase='train', auto_parallel_mode=True)

    eval_net.set_train(mode=False)
    eval_net.set_auto_parallel()
    _executor.compile(eval_net, _x, _b, phase='eval', auto_parallel_mode=True)

    context.reset_auto_parallel_context()
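Taken together, the test drives both new code paths: the train compile of Net checkpoints the full [8, 8] shape of w1, and the subsequent eval compile of EvalNet exercises the restore path, so evaluation-time shape inference starts from the original shape rather than a sliced one.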