wide_and_deep merge ckpt in eval

This commit is contained in:
yao_yf 2020-08-24 15:59:54 +08:00
parent 816ed8d842
commit eeede168fa
10 changed files with 71 additions and 19 deletions

View File

@ -343,7 +343,7 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
std::string strategy_key_name = "";
auto param_names = NodeParameterName(cnode);
if (!param_names.empty()) {
strategy_key_name = param_names[0].first;
strategy_key_name = prim->name() + "_" + param_names[0].first;
}
bool load_strategy_from_ckpt =
StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end();

View File

@ -1528,7 +1528,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
std::string strategy_key_name = "";
auto param_names = NodeParameterName(cnode);
if (!param_names.empty()) {
strategy_key_name = param_names[0].first;
strategy_key_name = prim->name() + "_" + param_names[0].first;
}
bool load_strategy_from_ckpt =
StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end();
@ -2219,24 +2219,35 @@ std::vector<std::pair<std::string, int>> NodeParameterName(const CNodePtr &node)
auto input = node_inputs[i];
if (input->isa<Parameter>()) {
auto input_parameter = input->cast<ParameterPtr>();
if (input_parameter->has_default()) {
if (ParameterRequireGrad(input_parameter)) {
if (input_parameter->has_default() && ParameterRequireGrad(input_parameter)) {
param_names.push_back({input_parameter->name(), i});
}
} else if (input->isa<CNode>()) {
CNodePtr cnode = input->cast<CNodePtr>();
if (!IsValueNode<Primitive>(cnode->input(0))) {
return param_names;
}
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
if (prim->name() == CAST && cnode->inputs().size() >= 1) {
auto cast_input = cnode->inputs()[1];
if (cast_input->isa<Parameter>()) {
auto cast_input_parameter = cast_input->cast<ParameterPtr>();
if (cast_input_parameter->has_default() && ParameterRequireGrad(cast_input_parameter)) {
param_names.push_back({cast_input_parameter->name(), i});
}
}
}
}
}
return param_names;
}
void CheckpointStrategy(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) {
MS_LOG(DEBUG) << "Save strategy to checkpoint begin";
StrategyMap stra_map;
TensorInfoMap tensor_info_map;
ManualShapeMap manual_shape_map;
auto ret = func_graph->get_return();
auto all_nodes = DeepScopedGraphSearch(ret);
for (auto &node : all_nodes) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
@ -2258,7 +2269,8 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) {
std::vector<TensorInfo> input_tensor_info = operator_info->inputs_tensor_info();
StrategyPtr strategyPtr = operator_info->strategy();
MS_EXCEPTION_IF_NULL(node->scope());
stra_map[param_name] = strategyPtr;
std::string stratey_key_name = prim->name() + "_" + param_name;
stra_map[stratey_key_name] = strategyPtr;
for (auto param_name_pair : param_names) {
if (param_name_pair.second - 1 >= UintToInt(input_tensor_info.size())) {
continue;
@ -2552,7 +2564,7 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
// save strategy as checkpoint for multi-train
if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) {
CheckpointStrategy(root);
CheckpointStrategy(all_nodes);
}
HandleSymbolicKeyInstance(root, all_nodes);

View File

@ -136,7 +136,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
std::vector<std::pair<std::string, int>> NodeParameterName(const CNodePtr &node);
void CheckpointStrategy(const FuncGraphPtr &func_graph);
void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes);
// main step of Parallel
bool StepParallel(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optimizer);

View File

@ -152,7 +152,11 @@ optional arguments:
--keep_prob The keep rate in dropout layer.(Default:1.0)
--dropout_flag Enable dropout.(Default:0)
--output_path Deprecated
--ckpt_path The location of the checkpoint file.(Defalut:./checkpoints/)
--ckpt_path The location of the checkpoint file. If the checkpoint file
is a slice of weight, multiple checkpoint files need to be
transferred. Use ';' to separate them and sort them in sequence
like "./checkpoints/0.ckpt;./checkpoints/1.ckpt".
(Defalut:./checkpoints/)
--eval_file_name Eval output file.(Default:eval.og)
--loss_file_name Loss output file.(Default:loss.log)
--host_device_mix Enable host device mode or not.(Default:0)

View File

@ -18,7 +18,8 @@
import os
from mindspore import Model, context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.serialization import load_checkpoint, load_param_into_net,\
build_searched_strategy, merge_sliced_parameter
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
from src.callbacks import LossCallBack, EvalCallBack
@ -81,8 +82,28 @@ def test_eval(config):
net_builder = ModelBuilder()
train_net, eval_net = net_builder.get_net(config)
param_dict = load_checkpoint(config.ckpt_path)
ckpt_path = config.ckpt_path
if ";" in ckpt_path:
ckpt_paths = ckpt_path.split(';')
param_list_dict = {}
strategy = build_searched_strategy(config.stra_ckpt)
for slice_path in ckpt_paths:
param_slice_dict = load_checkpoint(slice_path)
for key, value in param_slice_dict.items():
if 'optimizer' in key:
continue
if key not in param_list_dict:
param_list_dict[key] = []
param_list_dict[key].append(value)
param_dict = {}
for key, value in param_list_dict.items():
if key in strategy:
merged_parameter = merge_sliced_parameter(value, strategy)
else:
merged_parameter = merge_sliced_parameter(value)
param_dict[key] = merged_parameter
else:
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(eval_net, param_dict)
auc_metric = AUCMetric()

View File

@ -97,6 +97,7 @@ class EvalCallBack(Callback):
self.eval_file_name = config.eval_file_name
self.eval_values = []
self.host_device_mix = host_device_mix
self.config = config
def epoch_end(self, run_context):
"""
@ -106,7 +107,7 @@ class EvalCallBack(Callback):
parallel_mode = context.get_auto_parallel_context("parallel_mode")
if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
context.set_auto_parallel_context(strategy_ckpt_save_file="",
strategy_ckpt_load_file="./strategy_train.ckpt")
strategy_ckpt_load_file=self.config.stra_ckpt)
rank_id = 0
if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL,
ParallelMode.DATA_PARALLEL):

View File

@ -39,6 +39,8 @@ def argparse_init():
parser.add_argument("--dropout_flag", type=int, default=0, help="Enable dropout")
parser.add_argument("--output_path", type=str, default="./output/")
parser.add_argument("--ckpt_path", type=str, default="./checkpoints/", help="The location of the checkpoint file.")
parser.add_argument("--stra_ckpt", type=str, default="./checkpoints/strategy.ckpt",
help="The strategy checkpoint file.")
parser.add_argument("--eval_file_name", type=str, default="eval.log", help="Eval output file.")
parser.add_argument("--loss_file_name", type=str, default="loss.log", help="Loss output file.")
parser.add_argument("--host_device_mix", type=int, default=0, help="Enable host device mode or not")
@ -75,6 +77,7 @@ class WideDeepConfig():
self.eval_file_name = "eval.log"
self.loss_file_name = "loss.log"
self.ckpt_path = "./checkpoints/"
self.stra_ckpt = './checkpoints/strategy.ckpt'
self.host_device_mix = 0
self.dataset_type = "tfrecord"
self.parameter_server = 0
@ -107,6 +110,7 @@ class WideDeepConfig():
self.eval_file_name = args.eval_file_name
self.loss_file_name = args.loss_file_name
self.ckpt_path = args.ckpt_path
self.stra_ckpt = args.stra_ckpt
self.host_device_mix = args.host_device_mix
self.dataset_type = args.dataset_type
self.parameter_server = args.parameter_server

View File

@ -203,6 +203,7 @@ class WideDeepModel(nn.Cell):
self.dense_layer_1.dropout.dropout_do_mask.set_strategy(((1, get_group_size()),))
self.dense_layer_1.dropout.dropout.set_strategy(((1, get_group_size()),))
self.dense_layer_1.matmul.set_strategy(((1, get_group_size()), (get_group_size(), 1)))
self.dense_layer_1.matmul.add_prim_attr("field_size", config.field_size)
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim,
slice_mode=nn.EmbeddingLookUpSplitMode.TABLE_COLUMN_SLICE)
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1,
@ -211,6 +212,10 @@ class WideDeepModel(nn.Cell):
self.deep_reshape.add_prim_attr("skip_redistribution", True)
self.reduce_sum.add_prim_attr("cross_batch", True)
self.embedding_table = self.deep_embeddinglookup.embedding_table
elif host_device_mix:
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim)
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1)
self.embedding_table = self.deep_embeddinglookup.embedding_table
elif parameter_server:
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim)
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1)

View File

@ -111,10 +111,11 @@ def train_and_eval(config):
eval_callback = EvalCallBack(model, ds_eval, auc_metric, config, host_device_mix=host_device_mix)
callback = LossCallBack(config=config, per_print_times=20)
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*epochs,
keep_checkpoint_max=5, integrated_save=False)
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
directory=config.ckpt_path, config=ckptconfig)
context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_train.ckpt")
context.set_auto_parallel_context(strategy_ckpt_save_file=config.stra_ckpt)
callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback]
if not host_device_mix:
callback_list.append(ckpoint_cb)

View File

@ -30,6 +30,8 @@ def argparse_init():
parser.add_argument("--deep_layer_dim", type=int, nargs='+', default=[1024, 512, 256, 128])
parser.add_argument("--deep_layer_act", type=str, default='relu')
parser.add_argument("--keep_prob", type=float, default=1.0)
parser.add_argument("--stra_ckpt", type=str, default="./strategy_train.ckpt",
help="The strategy checkpoint file.")
parser.add_argument("--output_path", type=str, default="./output/")
parser.add_argument("--ckpt_path", type=str, default="./checkpoints/")
@ -63,6 +65,7 @@ class WideDeepConfig():
self.eval_file_name = "eval.log"
self.loss_file_name = "loss.log"
self.ckpt_path = "./checkpoints/"
self.stra_ckpt = "./strategy_train.ckpt"
def argparse_init(self):
"""
@ -90,3 +93,4 @@ class WideDeepConfig():
self.eval_file_name = args.eval_file_name
self.loss_file_name = args.loss_file_name
self.ckpt_path = args.ckpt_path
self.stra_ckpt = args.stra_ckpt