forked from mindspore-Ecosystem/mindspore
wide_and_deep merge ckpt in eval
This commit is contained in:
parent
816ed8d842
commit
eeede168fa
|
@ -343,7 +343,7 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
|
||||||
std::string strategy_key_name = "";
|
std::string strategy_key_name = "";
|
||||||
auto param_names = NodeParameterName(cnode);
|
auto param_names = NodeParameterName(cnode);
|
||||||
if (!param_names.empty()) {
|
if (!param_names.empty()) {
|
||||||
strategy_key_name = param_names[0].first;
|
strategy_key_name = prim->name() + "_" + param_names[0].first;
|
||||||
}
|
}
|
||||||
bool load_strategy_from_ckpt =
|
bool load_strategy_from_ckpt =
|
||||||
StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end();
|
StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end();
|
||||||
|
|
|
@ -1528,7 +1528,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
|
||||||
std::string strategy_key_name = "";
|
std::string strategy_key_name = "";
|
||||||
auto param_names = NodeParameterName(cnode);
|
auto param_names = NodeParameterName(cnode);
|
||||||
if (!param_names.empty()) {
|
if (!param_names.empty()) {
|
||||||
strategy_key_name = param_names[0].first;
|
strategy_key_name = prim->name() + "_" + param_names[0].first;
|
||||||
}
|
}
|
||||||
bool load_strategy_from_ckpt =
|
bool load_strategy_from_ckpt =
|
||||||
StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end();
|
StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end();
|
||||||
|
@ -2219,9 +2219,23 @@ std::vector<std::pair<std::string, int>> NodeParameterName(const CNodePtr &node)
|
||||||
auto input = node_inputs[i];
|
auto input = node_inputs[i];
|
||||||
if (input->isa<Parameter>()) {
|
if (input->isa<Parameter>()) {
|
||||||
auto input_parameter = input->cast<ParameterPtr>();
|
auto input_parameter = input->cast<ParameterPtr>();
|
||||||
if (input_parameter->has_default()) {
|
if (input_parameter->has_default() && ParameterRequireGrad(input_parameter)) {
|
||||||
if (ParameterRequireGrad(input_parameter)) {
|
param_names.push_back({input_parameter->name(), i});
|
||||||
param_names.push_back({input_parameter->name(), i});
|
}
|
||||||
|
} else if (input->isa<CNode>()) {
|
||||||
|
CNodePtr cnode = input->cast<CNodePtr>();
|
||||||
|
if (!IsValueNode<Primitive>(cnode->input(0))) {
|
||||||
|
return param_names;
|
||||||
|
}
|
||||||
|
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
|
||||||
|
PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
|
||||||
|
if (prim->name() == CAST && cnode->inputs().size() >= 1) {
|
||||||
|
auto cast_input = cnode->inputs()[1];
|
||||||
|
if (cast_input->isa<Parameter>()) {
|
||||||
|
auto cast_input_parameter = cast_input->cast<ParameterPtr>();
|
||||||
|
if (cast_input_parameter->has_default() && ParameterRequireGrad(cast_input_parameter)) {
|
||||||
|
param_names.push_back({cast_input_parameter->name(), i});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2229,14 +2243,11 @@ std::vector<std::pair<std::string, int>> NodeParameterName(const CNodePtr &node)
|
||||||
return param_names;
|
return param_names;
|
||||||
}
|
}
|
||||||
|
|
||||||
void CheckpointStrategy(const FuncGraphPtr &func_graph) {
|
void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
|
||||||
MS_LOG(DEBUG) << "Save strategy to checkpoint begin";
|
MS_LOG(DEBUG) << "Save strategy to checkpoint begin";
|
||||||
StrategyMap stra_map;
|
StrategyMap stra_map;
|
||||||
TensorInfoMap tensor_info_map;
|
TensorInfoMap tensor_info_map;
|
||||||
ManualShapeMap manual_shape_map;
|
ManualShapeMap manual_shape_map;
|
||||||
auto ret = func_graph->get_return();
|
|
||||||
auto all_nodes = DeepScopedGraphSearch(ret);
|
|
||||||
for (auto &node : all_nodes) {
|
for (auto &node : all_nodes) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
auto cnode = node->cast<CNodePtr>();
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
@ -2258,7 +2269,8 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) {
|
||||||
std::vector<TensorInfo> input_tensor_info = operator_info->inputs_tensor_info();
|
std::vector<TensorInfo> input_tensor_info = operator_info->inputs_tensor_info();
|
||||||
StrategyPtr strategyPtr = operator_info->strategy();
|
StrategyPtr strategyPtr = operator_info->strategy();
|
||||||
MS_EXCEPTION_IF_NULL(node->scope());
|
MS_EXCEPTION_IF_NULL(node->scope());
|
||||||
stra_map[param_name] = strategyPtr;
|
std::string stratey_key_name = prim->name() + "_" + param_name;
|
||||||
|
stra_map[stratey_key_name] = strategyPtr;
|
||||||
for (auto param_name_pair : param_names) {
|
for (auto param_name_pair : param_names) {
|
||||||
if (param_name_pair.second - 1 >= UintToInt(input_tensor_info.size())) {
|
if (param_name_pair.second - 1 >= UintToInt(input_tensor_info.size())) {
|
||||||
continue;
|
continue;
|
||||||
|
@ -2552,7 +2564,7 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
|
||||||
|
|
||||||
// save strategy as checkpoint for multi-train
|
// save strategy as checkpoint for multi-train
|
||||||
if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) {
|
if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) {
|
||||||
CheckpointStrategy(root);
|
CheckpointStrategy(all_nodes);
|
||||||
}
|
}
|
||||||
|
|
||||||
HandleSymbolicKeyInstance(root, all_nodes);
|
HandleSymbolicKeyInstance(root, all_nodes);
|
||||||
|
|
|
@ -136,7 +136,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
|
||||||
|
|
||||||
std::vector<std::pair<std::string, int>> NodeParameterName(const CNodePtr &node);
|
std::vector<std::pair<std::string, int>> NodeParameterName(const CNodePtr &node);
|
||||||
|
|
||||||
void CheckpointStrategy(const FuncGraphPtr &func_graph);
|
void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes);
|
||||||
|
|
||||||
// main step of Parallel
|
// main step of Parallel
|
||||||
bool StepParallel(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optimizer);
|
bool StepParallel(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optimizer);
|
||||||
|
|
|
@ -152,7 +152,11 @@ optional arguments:
|
||||||
--keep_prob The keep rate in dropout layer.(Default:1.0)
|
--keep_prob The keep rate in dropout layer.(Default:1.0)
|
||||||
--dropout_flag Enable dropout.(Default:0)
|
--dropout_flag Enable dropout.(Default:0)
|
||||||
--output_path Deprecated
|
--output_path Deprecated
|
||||||
--ckpt_path The location of the checkpoint file.(Default:./checkpoints/)
|
--ckpt_path The location of the checkpoint file. If the checkpoint file
|
||||||
|
is a slice of weight, multiple checkpoint files need to be
|
||||||
|
transferred. Use ';' to separate them and sort them in sequence
|
||||||
|
like "./checkpoints/0.ckpt;./checkpoints/1.ckpt".
|
||||||
|
(Default:./checkpoints/)
|
||||||
--eval_file_name Eval output file.(Default:eval.log)
|
--eval_file_name Eval output file.(Default:eval.log)
|
||||||
--loss_file_name Loss output file.(Default:loss.log)
|
--loss_file_name Loss output file.(Default:loss.log)
|
||||||
--host_device_mix Enable host device mode or not.(Default:0)
|
--host_device_mix Enable host device mode or not.(Default:0)
|
||||||
|
|
|
@ -18,7 +18,8 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from mindspore import Model, context
|
from mindspore import Model, context
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net,\
|
||||||
|
build_searched_strategy, merge_sliced_parameter
|
||||||
|
|
||||||
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
|
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
|
||||||
from src.callbacks import LossCallBack, EvalCallBack
|
from src.callbacks import LossCallBack, EvalCallBack
|
||||||
|
@ -81,8 +82,28 @@ def test_eval(config):
|
||||||
|
|
||||||
net_builder = ModelBuilder()
|
net_builder = ModelBuilder()
|
||||||
train_net, eval_net = net_builder.get_net(config)
|
train_net, eval_net = net_builder.get_net(config)
|
||||||
|
ckpt_path = config.ckpt_path
|
||||||
param_dict = load_checkpoint(config.ckpt_path)
|
if ";" in ckpt_path:
|
||||||
|
ckpt_paths = ckpt_path.split(';')
|
||||||
|
param_list_dict = {}
|
||||||
|
strategy = build_searched_strategy(config.stra_ckpt)
|
||||||
|
for slice_path in ckpt_paths:
|
||||||
|
param_slice_dict = load_checkpoint(slice_path)
|
||||||
|
for key, value in param_slice_dict.items():
|
||||||
|
if 'optimizer' in key:
|
||||||
|
continue
|
||||||
|
if key not in param_list_dict:
|
||||||
|
param_list_dict[key] = []
|
||||||
|
param_list_dict[key].append(value)
|
||||||
|
param_dict = {}
|
||||||
|
for key, value in param_list_dict.items():
|
||||||
|
if key in strategy:
|
||||||
|
merged_parameter = merge_sliced_parameter(value, strategy)
|
||||||
|
else:
|
||||||
|
merged_parameter = merge_sliced_parameter(value)
|
||||||
|
param_dict[key] = merged_parameter
|
||||||
|
else:
|
||||||
|
param_dict = load_checkpoint(ckpt_path)
|
||||||
load_param_into_net(eval_net, param_dict)
|
load_param_into_net(eval_net, param_dict)
|
||||||
|
|
||||||
auc_metric = AUCMetric()
|
auc_metric = AUCMetric()
|
||||||
|
|
|
@ -97,6 +97,7 @@ class EvalCallBack(Callback):
|
||||||
self.eval_file_name = config.eval_file_name
|
self.eval_file_name = config.eval_file_name
|
||||||
self.eval_values = []
|
self.eval_values = []
|
||||||
self.host_device_mix = host_device_mix
|
self.host_device_mix = host_device_mix
|
||||||
|
self.config = config
|
||||||
|
|
||||||
def epoch_end(self, run_context):
|
def epoch_end(self, run_context):
|
||||||
"""
|
"""
|
||||||
|
@ -106,7 +107,7 @@ class EvalCallBack(Callback):
|
||||||
parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||||
if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||||
context.set_auto_parallel_context(strategy_ckpt_save_file="",
|
context.set_auto_parallel_context(strategy_ckpt_save_file="",
|
||||||
strategy_ckpt_load_file="./strategy_train.ckpt")
|
strategy_ckpt_load_file=self.config.stra_ckpt)
|
||||||
rank_id = 0
|
rank_id = 0
|
||||||
if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL,
|
if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL,
|
||||||
ParallelMode.DATA_PARALLEL):
|
ParallelMode.DATA_PARALLEL):
|
||||||
|
|
|
@ -39,6 +39,8 @@ def argparse_init():
|
||||||
parser.add_argument("--dropout_flag", type=int, default=0, help="Enable dropout")
|
parser.add_argument("--dropout_flag", type=int, default=0, help="Enable dropout")
|
||||||
parser.add_argument("--output_path", type=str, default="./output/")
|
parser.add_argument("--output_path", type=str, default="./output/")
|
||||||
parser.add_argument("--ckpt_path", type=str, default="./checkpoints/", help="The location of the checkpoint file.")
|
parser.add_argument("--ckpt_path", type=str, default="./checkpoints/", help="The location of the checkpoint file.")
|
||||||
|
parser.add_argument("--stra_ckpt", type=str, default="./checkpoints/strategy.ckpt",
|
||||||
|
help="The strategy checkpoint file.")
|
||||||
parser.add_argument("--eval_file_name", type=str, default="eval.log", help="Eval output file.")
|
parser.add_argument("--eval_file_name", type=str, default="eval.log", help="Eval output file.")
|
||||||
parser.add_argument("--loss_file_name", type=str, default="loss.log", help="Loss output file.")
|
parser.add_argument("--loss_file_name", type=str, default="loss.log", help="Loss output file.")
|
||||||
parser.add_argument("--host_device_mix", type=int, default=0, help="Enable host device mode or not")
|
parser.add_argument("--host_device_mix", type=int, default=0, help="Enable host device mode or not")
|
||||||
|
@ -75,6 +77,7 @@ class WideDeepConfig():
|
||||||
self.eval_file_name = "eval.log"
|
self.eval_file_name = "eval.log"
|
||||||
self.loss_file_name = "loss.log"
|
self.loss_file_name = "loss.log"
|
||||||
self.ckpt_path = "./checkpoints/"
|
self.ckpt_path = "./checkpoints/"
|
||||||
|
self.stra_ckpt = './checkpoints/strategy.ckpt'
|
||||||
self.host_device_mix = 0
|
self.host_device_mix = 0
|
||||||
self.dataset_type = "tfrecord"
|
self.dataset_type = "tfrecord"
|
||||||
self.parameter_server = 0
|
self.parameter_server = 0
|
||||||
|
@ -107,6 +110,7 @@ class WideDeepConfig():
|
||||||
self.eval_file_name = args.eval_file_name
|
self.eval_file_name = args.eval_file_name
|
||||||
self.loss_file_name = args.loss_file_name
|
self.loss_file_name = args.loss_file_name
|
||||||
self.ckpt_path = args.ckpt_path
|
self.ckpt_path = args.ckpt_path
|
||||||
|
self.stra_ckpt = args.stra_ckpt
|
||||||
self.host_device_mix = args.host_device_mix
|
self.host_device_mix = args.host_device_mix
|
||||||
self.dataset_type = args.dataset_type
|
self.dataset_type = args.dataset_type
|
||||||
self.parameter_server = args.parameter_server
|
self.parameter_server = args.parameter_server
|
||||||
|
|
|
@ -203,6 +203,7 @@ class WideDeepModel(nn.Cell):
|
||||||
self.dense_layer_1.dropout.dropout_do_mask.set_strategy(((1, get_group_size()),))
|
self.dense_layer_1.dropout.dropout_do_mask.set_strategy(((1, get_group_size()),))
|
||||||
self.dense_layer_1.dropout.dropout.set_strategy(((1, get_group_size()),))
|
self.dense_layer_1.dropout.dropout.set_strategy(((1, get_group_size()),))
|
||||||
self.dense_layer_1.matmul.set_strategy(((1, get_group_size()), (get_group_size(), 1)))
|
self.dense_layer_1.matmul.set_strategy(((1, get_group_size()), (get_group_size(), 1)))
|
||||||
|
self.dense_layer_1.matmul.add_prim_attr("field_size", config.field_size)
|
||||||
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim,
|
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim,
|
||||||
slice_mode=nn.EmbeddingLookUpSplitMode.TABLE_COLUMN_SLICE)
|
slice_mode=nn.EmbeddingLookUpSplitMode.TABLE_COLUMN_SLICE)
|
||||||
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1,
|
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1,
|
||||||
|
@ -211,6 +212,10 @@ class WideDeepModel(nn.Cell):
|
||||||
self.deep_reshape.add_prim_attr("skip_redistribution", True)
|
self.deep_reshape.add_prim_attr("skip_redistribution", True)
|
||||||
self.reduce_sum.add_prim_attr("cross_batch", True)
|
self.reduce_sum.add_prim_attr("cross_batch", True)
|
||||||
self.embedding_table = self.deep_embeddinglookup.embedding_table
|
self.embedding_table = self.deep_embeddinglookup.embedding_table
|
||||||
|
elif host_device_mix:
|
||||||
|
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim)
|
||||||
|
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1)
|
||||||
|
self.embedding_table = self.deep_embeddinglookup.embedding_table
|
||||||
elif parameter_server:
|
elif parameter_server:
|
||||||
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim)
|
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim)
|
||||||
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1)
|
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1)
|
||||||
|
|
|
@ -111,10 +111,11 @@ def train_and_eval(config):
|
||||||
eval_callback = EvalCallBack(model, ds_eval, auc_metric, config, host_device_mix=host_device_mix)
|
eval_callback = EvalCallBack(model, ds_eval, auc_metric, config, host_device_mix=host_device_mix)
|
||||||
|
|
||||||
callback = LossCallBack(config=config, per_print_times=20)
|
callback = LossCallBack(config=config, per_print_times=20)
|
||||||
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
|
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*epochs,
|
||||||
|
keep_checkpoint_max=5, integrated_save=False)
|
||||||
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
|
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
|
||||||
directory=config.ckpt_path, config=ckptconfig)
|
directory=config.ckpt_path, config=ckptconfig)
|
||||||
context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_train.ckpt")
|
context.set_auto_parallel_context(strategy_ckpt_save_file=config.stra_ckpt)
|
||||||
callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback]
|
callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback]
|
||||||
if not host_device_mix:
|
if not host_device_mix:
|
||||||
callback_list.append(ckpoint_cb)
|
callback_list.append(ckpoint_cb)
|
||||||
|
|
|
@ -30,6 +30,8 @@ def argparse_init():
|
||||||
parser.add_argument("--deep_layer_dim", type=int, nargs='+', default=[1024, 512, 256, 128])
|
parser.add_argument("--deep_layer_dim", type=int, nargs='+', default=[1024, 512, 256, 128])
|
||||||
parser.add_argument("--deep_layer_act", type=str, default='relu')
|
parser.add_argument("--deep_layer_act", type=str, default='relu')
|
||||||
parser.add_argument("--keep_prob", type=float, default=1.0)
|
parser.add_argument("--keep_prob", type=float, default=1.0)
|
||||||
|
parser.add_argument("--stra_ckpt", type=str, default="./strategy_train.ckpt",
|
||||||
|
help="The strategy checkpoint file.")
|
||||||
|
|
||||||
parser.add_argument("--output_path", type=str, default="./output/")
|
parser.add_argument("--output_path", type=str, default="./output/")
|
||||||
parser.add_argument("--ckpt_path", type=str, default="./checkpoints/")
|
parser.add_argument("--ckpt_path", type=str, default="./checkpoints/")
|
||||||
|
@ -63,6 +65,7 @@ class WideDeepConfig():
|
||||||
self.eval_file_name = "eval.log"
|
self.eval_file_name = "eval.log"
|
||||||
self.loss_file_name = "loss.log"
|
self.loss_file_name = "loss.log"
|
||||||
self.ckpt_path = "./checkpoints/"
|
self.ckpt_path = "./checkpoints/"
|
||||||
|
self.stra_ckpt = "./strategy_train.ckpt"
|
||||||
|
|
||||||
def argparse_init(self):
|
def argparse_init(self):
|
||||||
"""
|
"""
|
||||||
|
@ -90,3 +93,4 @@ class WideDeepConfig():
|
||||||
self.eval_file_name = args.eval_file_name
|
self.eval_file_name = args.eval_file_name
|
||||||
self.loss_file_name = args.loss_file_name
|
self.loss_file_name = args.loss_file_name
|
||||||
self.ckpt_path = args.ckpt_path
|
self.ckpt_path = args.ckpt_path
|
||||||
|
self.stra_ckpt = args.stra_ckpt
|
||||||
|
|
Loading…
Reference in New Issue