ckpt and restore parameter shape

parent 311b7e71af
commit 6b54a6417d
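Under auto-parallel or semi-auto-parallel mode, checkpoint each parameter's full shape into a static map (param_shapes) while the training graph is specialized, and restore the recorded shape onto the parameter's abstract when the same network is later specialized for a non-training (eval/predict) graph; presumably the eval compile would otherwise see the sliced, per-device shapes that parallel training leaves behind. On the Python side, _Executor.compile now loads parameter layouts and parameter slices only for training phases.

The first two hunks belong to the parallel context implementation (likely parallel/context.cc, judging by the includes and namespaces).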
@@ -22,12 +22,15 @@
 #include <memory>
 #include <numeric>
 #include <utility>
+#include <map>
 
 #include "common/utils.h"
 #include "parallel/device_manager.h"
 
 namespace mindspore {
 namespace parallel {
+static std::map<std::string, std::vector<int>> param_shapes;
+
 std::vector<std::string> PARALLEL_MODE_LIST = {STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL,
                                                AUTO_PARALLEL};
 std::vector<std::string> STRATEGY_SEARCH_MODE_LIST = {DYNAMIC_PROGRAMMING, RECURSIVE_PROGRAMMING};
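(param_shapes has internal linkage because of the static qualifier, so the registry stays private to this translation unit; the three functions added in the next hunk are its only accessors in this commit.)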
@@ -136,5 +139,56 @@ const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const
   }
   return {};
 }
+
+// Clear param_shapes before training in auto-parallel or semi-auto-parallel mode
+void ParallelParameterContextInit(const FuncGraphPtr &func_graph) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  if (!func_graph->has_flag(AUTO_PARALLEL) || !func_graph->has_flag(TRAINING)) {
+    return;
+  }
+  param_shapes.clear();
+}
+
+// Restore the parameters' shape for evaluation/prediction in auto-parallel or semi-auto-parallel mode
+void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
+                                                 AbstractBasePtr ptr) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(param_node);
+  MS_EXCEPTION_IF_NULL(ptr);
+  if (!func_graph->has_flag(AUTO_PARALLEL) || (func_graph->flags().count(TRAINING) == 0) ||
+      func_graph->flags()[TRAINING]) {
+    return;
+  }
+
+  auto iter = param_shapes.find(param_node->name());
+  if (iter == param_shapes.end()) {
+    MS_LOG(WARNING) << "Cannot find the shape for parameter " << param_node->name();
+    return;
+  }
+  std::vector<int> shape = iter->second;
+  std::shared_ptr<abstract::BaseShape> base_shape = std::make_shared<abstract::Shape>(shape);
+  ptr->set_shape(base_shape);
+  MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
+}
+
+// Checkpoint the parameters' shape for training in auto-parallel or semi-auto-parallel mode
+void ParallelParameterContextCkptInTraining(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
+                                            const AbstractBasePtr &ptr) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(param_node);
+  MS_EXCEPTION_IF_NULL(ptr);
+  if (!func_graph->has_flag(AUTO_PARALLEL) || !func_graph->has_flag(TRAINING)) {
+    return;
+  }
+
+  std::vector<int> shape = dyn_cast<abstract::Shape>(ptr->GetShapeTrack())->shape();
+  auto ret = param_shapes.try_emplace(param_node->name(), shape);
+  if (!ret.second) {
+    MS_LOG(EXCEPTION) << "The shape for parameter " << param_node->name() << " already exists";
+    return;
+  }
+
+  MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
+}
 }  // namespace parallel
 }  // namespace mindspore
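ParallelParameterContextCkptInTraining relies on std::map::try_emplace (C++17): if the key is already present, nothing is overwritten and the bool in the returned pair comes back false, which the code above escalates to MS_LOG(EXCEPTION) (the return after it looks unreachable, since MS_LOG(EXCEPTION) throws in MindSpore). A minimal standalone sketch of that dedupe behaviour; the names here are illustrative, not from the commit:

#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  std::map<std::string, std::vector<int>> shapes;  // stand-in for param_shapes
  auto first = shapes.try_emplace("w1", std::vector<int>{8, 8});   // inserted, first.second == true
  auto second = shapes.try_emplace("w1", std::vector<int>{2, 2});  // key exists, second.second == false
  std::cout << std::boolalpha << first.second << ' ' << second.second << '\n';  // true false
  std::cout << shapes["w1"][0] << '\n';  // 8 -- the original shape survives untouched
  return 0;
}

The next three hunks are from the matching header (likely parallel/context.h): new includes, a TRAINING flag key, and the declarations.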
@@ -26,6 +26,9 @@
 #include "parallel/ops_info/ops_utils.h"
 #include "parallel/status.h"
 #include "utils/convert_utils.h"
+#include "ir/anf.h"
+#include "ir/func_graph.h"
+#include "debug/info.h"
 
 namespace mindspore {
 namespace parallel {
@@ -38,6 +41,8 @@ constexpr char SEMI_AUTO_PARALLEL[] = "semi_auto_parallel";
 constexpr char DYNAMIC_PROGRAMMING[] = "dynamic_programming";
 constexpr char RECURSIVE_PROGRAMMING[] = "recursive_programming";
 
+constexpr char TRAINING[] = "training";
+
 class ParallelContext {
  public:
   ~ParallelContext() = default;
@@ -114,6 +119,12 @@ class ParallelContext {
   std::string strategy_ckpt_load_file_;
   std::string strategy_ckpt_save_file_;
 };
+
+void ParallelParameterContextInit(const FuncGraphPtr &func_graph);
+void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
+                                                 AbstractBasePtr ptr);
+void ParallelParameterContextCkptInTraining(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
+                                            const AbstractBasePtr &ptr);
 }  // namespace parallel
 }  // namespace mindspore
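The next hunks hook the new functions into graph specialization, apparently in pipeline/action.cc where AbstractSpecializeAction lives: the registry is cleared once per specialization, then for each parameter with a default value the restore hook runs before the abstract is pushed into args_spec and the checkpoint hook after it. Each hook is a no-op unless its training/non-training guard matches, so the same call sites serve both the train and the eval compile.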
@@ -25,6 +25,7 @@
 
 #include "ir/func_graph_cloner.h"
 #include "parallel/costmodel_context.h"
+#include "parallel/context.h"
 #include "pipeline/pass.h"
 #include "pipeline/parse/parse_base.h"
 #include "pipeline/parse/data_converter.h"
@@ -217,6 +218,8 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
   FuncGraphPtr func_graph = res->func_graph();
   abstract::AbstractBasePtrList args_spec = res->args_spec();
 
+  parallel::ParallelParameterContextInit(func_graph);
+
   // suppose that there is not KeywordArgument for the top graph
   // get the hyper parameter
   for (const auto &param : func_graph->parameters()) {
@@ -224,7 +227,10 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
     if (param_node->has_default()) {
       AbstractBasePtr ptr =
         abstract::FromValue(parse::data_converter::PyDataToValue(param_node->default_param()), true);
+
+      parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr);
       args_spec.push_back(ptr);
+      parallel::ParallelParameterContextCkptInTraining(func_graph, param_node, ptr);
     }
   }
   // Analyze
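On the Python side (class _Executor, i.e. mindspore/common/api.py by the class name), parameter layouts are now fetched and parameter slices loaded only when the phase names a training compile, so an eval compile keeps the full shapes restored by the C++ side: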
@@ -379,7 +379,7 @@ class _Executor:
 
         self._params_init_data(obj, params)
         if not enable_debug_runtime or enable_ge:
-            if auto_parallel_mode:
+            if auto_parallel_mode and "train" in phase:
                 obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
                 obj.load_parameter_slice(params)
 
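An existing test must now pass the phase explicitly, since layout loading is gated on it: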
@@ -47,7 +47,7 @@ def test_get_parameter_layout():
     net = Net(strategy1, strategy2, weight)
     net.set_auto_parallel()
     exe = me._executor
-    exe.compile(net, x, auto_parallel_mode=True)
+    exe.compile(net, x, phase='train', auto_parallel_mode=True)
     x_layout = [[2, 4], [1, -1], [16, 32]]  # device_arrangement = [2, 4], tensor_map = [1, -1]
     weight_layout = [[2, 4], [0, -1], [16, 32]]  # device_arrangement = [2, 4], tensor_map = [0, -1]
     expect_dict = {'x': x_layout, 'w1': weight_layout}
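Finally, a new 68-line test compiles the same weight-sharing network twice, phase 'train' then phase 'eval', under semi_auto_parallel with 16 devices, exercising the checkpoint-then-restore path end to end: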
@@ -0,0 +1,68 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import mindspore as ms
+from mindspore import context, Tensor, Parameter
+from mindspore.nn import Cell, TrainOneStepCell, Momentum
+from mindspore.ops import operations as P
+from mindspore.common.api import _executor
+
+
+class Net(Cell):
+    def __init__(self, mul_weight, strategy1=None, strategy2=None):
+        super().__init__()
+        self.mul = P.Mul().set_strategy(strategy1)
+        self.neg = P.Neg().set_strategy(strategy2)
+        self.mul_weight = Parameter(mul_weight, "w1")
+
+    def construct(self, x, b):
+        out = self.mul(x, self.mul_weight)
+        out = self.neg(out)
+        return out
+
+
+class EvalNet(Cell):
+    def __init__(self, network, strategy2=None):
+        super().__init__()
+        self.network = network
+        self.relu = P.ReLU().set_strategy(strategy2)
+
+    def construct(self, x, b):
+        out = self.network(x, b)
+        out = self.relu(out)
+        return out
+
+
+_x = Tensor(np.ones([8, 8]), dtype=ms.float32)
+_w1 = Tensor(np.ones([8, 8]), dtype=ms.float32)
+_b = Tensor(np.ones([8, 8]), dtype=ms.float32)
+
+
+def test_train_and_eval():
+    context.set_context(save_graphs=True, mode=0)
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16)
+    strategy1 = ((4, 4), (4, 4))
+    strategy2 = ((4, 4), )
+    net = Net(_w1, strategy1, strategy2)
+    eval_net = EvalNet(net, strategy2=strategy2)
+    net.set_train()
+    net.set_auto_parallel()
+    _executor.compile(net, _x, _b, phase='train', auto_parallel_mode=True)
+
+    eval_net.set_train(mode=False)
+    eval_net.set_auto_parallel()
+    _executor.compile(eval_net, _x, _b, phase='eval', auto_parallel_mode=True)
+
+    context.reset_auto_parallel_context()