forked from mindspore-Ecosystem/mindspore
wide_and_deep merge ckpt in eval
This commit is contained in:
parent
816ed8d842
commit
eeede168fa
|
@ -343,7 +343,7 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
|
|||
std::string strategy_key_name = "";
|
||||
auto param_names = NodeParameterName(cnode);
|
||||
if (!param_names.empty()) {
|
||||
strategy_key_name = param_names[0].first;
|
||||
strategy_key_name = prim->name() + "_" + param_names[0].first;
|
||||
}
|
||||
bool load_strategy_from_ckpt =
|
||||
StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end();
|
||||
|
|
|
@ -1528,7 +1528,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
|
|||
std::string strategy_key_name = "";
|
||||
auto param_names = NodeParameterName(cnode);
|
||||
if (!param_names.empty()) {
|
||||
strategy_key_name = param_names[0].first;
|
||||
strategy_key_name = prim->name() + "_" + param_names[0].first;
|
||||
}
|
||||
bool load_strategy_from_ckpt =
|
||||
StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end();
|
||||
|
@ -2219,9 +2219,23 @@ std::vector<std::pair<std::string, int>> NodeParameterName(const CNodePtr &node)
|
|||
auto input = node_inputs[i];
|
||||
if (input->isa<Parameter>()) {
|
||||
auto input_parameter = input->cast<ParameterPtr>();
|
||||
if (input_parameter->has_default()) {
|
||||
if (ParameterRequireGrad(input_parameter)) {
|
||||
param_names.push_back({input_parameter->name(), i});
|
||||
if (input_parameter->has_default() && ParameterRequireGrad(input_parameter)) {
|
||||
param_names.push_back({input_parameter->name(), i});
|
||||
}
|
||||
} else if (input->isa<CNode>()) {
|
||||
CNodePtr cnode = input->cast<CNodePtr>();
|
||||
if (!IsValueNode<Primitive>(cnode->input(0))) {
|
||||
return param_names;
|
||||
}
|
||||
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
|
||||
PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
|
||||
if (prim->name() == CAST && cnode->inputs().size() >= 1) {
|
||||
auto cast_input = cnode->inputs()[1];
|
||||
if (cast_input->isa<Parameter>()) {
|
||||
auto cast_input_parameter = cast_input->cast<ParameterPtr>();
|
||||
if (cast_input_parameter->has_default() && ParameterRequireGrad(cast_input_parameter)) {
|
||||
param_names.push_back({cast_input_parameter->name(), i});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2229,14 +2243,11 @@ std::vector<std::pair<std::string, int>> NodeParameterName(const CNodePtr &node)
|
|||
return param_names;
|
||||
}
|
||||
|
||||
void CheckpointStrategy(const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) {
|
||||
MS_LOG(DEBUG) << "Save strategy to checkpoint begin";
|
||||
StrategyMap stra_map;
|
||||
TensorInfoMap tensor_info_map;
|
||||
ManualShapeMap manual_shape_map;
|
||||
auto ret = func_graph->get_return();
|
||||
auto all_nodes = DeepScopedGraphSearch(ret);
|
||||
for (auto &node : all_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
|
@ -2258,7 +2269,8 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) {
|
|||
std::vector<TensorInfo> input_tensor_info = operator_info->inputs_tensor_info();
|
||||
StrategyPtr strategyPtr = operator_info->strategy();
|
||||
MS_EXCEPTION_IF_NULL(node->scope());
|
||||
stra_map[param_name] = strategyPtr;
|
||||
std::string stratey_key_name = prim->name() + "_" + param_name;
|
||||
stra_map[stratey_key_name] = strategyPtr;
|
||||
for (auto param_name_pair : param_names) {
|
||||
if (param_name_pair.second - 1 >= UintToInt(input_tensor_info.size())) {
|
||||
continue;
|
||||
|
@ -2552,7 +2564,7 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
|
|||
|
||||
// save strategy as checkpoint for multi-train
|
||||
if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) {
|
||||
CheckpointStrategy(root);
|
||||
CheckpointStrategy(all_nodes);
|
||||
}
|
||||
|
||||
HandleSymbolicKeyInstance(root, all_nodes);
|
||||
|
|
|
@ -136,7 +136,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
|
|||
|
||||
std::vector<std::pair<std::string, int>> NodeParameterName(const CNodePtr &node);
|
||||
|
||||
void CheckpointStrategy(const FuncGraphPtr &func_graph);
|
||||
void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes);
|
||||
|
||||
// main step of Parallel
|
||||
bool StepParallel(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optimizer);
|
||||
|
|
|
@ -152,7 +152,11 @@ optional arguments:
|
|||
--keep_prob The keep rate in dropout layer.(Default:1.0)
|
||||
--dropout_flag Enable dropout.(Default:0)
|
||||
--output_path Deprecated
|
||||
--ckpt_path The location of the checkpoint file.(Defalut:./checkpoints/)
|
||||
--ckpt_path The location of the checkpoint file. If the checkpoint file
|
||||
is a slice of weight, multiple checkpoint files need to be
|
||||
transferred. Use ';' to separate them and sort them in sequence
|
||||
like "./checkpoints/0.ckpt;./checkpoints/1.ckpt".
|
||||
(Defalut:./checkpoints/)
|
||||
--eval_file_name Eval output file.(Default:eval.og)
|
||||
--loss_file_name Loss output file.(Default:loss.log)
|
||||
--host_device_mix Enable host device mode or not.(Default:0)
|
||||
|
|
|
@ -18,7 +18,8 @@
|
|||
import os
|
||||
|
||||
from mindspore import Model, context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net,\
|
||||
build_searched_strategy, merge_sliced_parameter
|
||||
|
||||
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
|
||||
from src.callbacks import LossCallBack, EvalCallBack
|
||||
|
@ -81,8 +82,28 @@ def test_eval(config):
|
|||
|
||||
net_builder = ModelBuilder()
|
||||
train_net, eval_net = net_builder.get_net(config)
|
||||
|
||||
param_dict = load_checkpoint(config.ckpt_path)
|
||||
ckpt_path = config.ckpt_path
|
||||
if ";" in ckpt_path:
|
||||
ckpt_paths = ckpt_path.split(';')
|
||||
param_list_dict = {}
|
||||
strategy = build_searched_strategy(config.stra_ckpt)
|
||||
for slice_path in ckpt_paths:
|
||||
param_slice_dict = load_checkpoint(slice_path)
|
||||
for key, value in param_slice_dict.items():
|
||||
if 'optimizer' in key:
|
||||
continue
|
||||
if key not in param_list_dict:
|
||||
param_list_dict[key] = []
|
||||
param_list_dict[key].append(value)
|
||||
param_dict = {}
|
||||
for key, value in param_list_dict.items():
|
||||
if key in strategy:
|
||||
merged_parameter = merge_sliced_parameter(value, strategy)
|
||||
else:
|
||||
merged_parameter = merge_sliced_parameter(value)
|
||||
param_dict[key] = merged_parameter
|
||||
else:
|
||||
param_dict = load_checkpoint(ckpt_path)
|
||||
load_param_into_net(eval_net, param_dict)
|
||||
|
||||
auc_metric = AUCMetric()
|
||||
|
|
|
@ -97,6 +97,7 @@ class EvalCallBack(Callback):
|
|||
self.eval_file_name = config.eval_file_name
|
||||
self.eval_values = []
|
||||
self.host_device_mix = host_device_mix
|
||||
self.config = config
|
||||
|
||||
def epoch_end(self, run_context):
|
||||
"""
|
||||
|
@ -106,7 +107,7 @@ class EvalCallBack(Callback):
|
|||
parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||
if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||
context.set_auto_parallel_context(strategy_ckpt_save_file="",
|
||||
strategy_ckpt_load_file="./strategy_train.ckpt")
|
||||
strategy_ckpt_load_file=self.config.stra_ckpt)
|
||||
rank_id = 0
|
||||
if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL,
|
||||
ParallelMode.DATA_PARALLEL):
|
||||
|
|
|
@ -39,6 +39,8 @@ def argparse_init():
|
|||
parser.add_argument("--dropout_flag", type=int, default=0, help="Enable dropout")
|
||||
parser.add_argument("--output_path", type=str, default="./output/")
|
||||
parser.add_argument("--ckpt_path", type=str, default="./checkpoints/", help="The location of the checkpoint file.")
|
||||
parser.add_argument("--stra_ckpt", type=str, default="./checkpoints/strategy.ckpt",
|
||||
help="The strategy checkpoint file.")
|
||||
parser.add_argument("--eval_file_name", type=str, default="eval.log", help="Eval output file.")
|
||||
parser.add_argument("--loss_file_name", type=str, default="loss.log", help="Loss output file.")
|
||||
parser.add_argument("--host_device_mix", type=int, default=0, help="Enable host device mode or not")
|
||||
|
@ -75,6 +77,7 @@ class WideDeepConfig():
|
|||
self.eval_file_name = "eval.log"
|
||||
self.loss_file_name = "loss.log"
|
||||
self.ckpt_path = "./checkpoints/"
|
||||
self.stra_ckpt = './checkpoints/strategy.ckpt'
|
||||
self.host_device_mix = 0
|
||||
self.dataset_type = "tfrecord"
|
||||
self.parameter_server = 0
|
||||
|
@ -107,6 +110,7 @@ class WideDeepConfig():
|
|||
self.eval_file_name = args.eval_file_name
|
||||
self.loss_file_name = args.loss_file_name
|
||||
self.ckpt_path = args.ckpt_path
|
||||
self.stra_ckpt = args.stra_ckpt
|
||||
self.host_device_mix = args.host_device_mix
|
||||
self.dataset_type = args.dataset_type
|
||||
self.parameter_server = args.parameter_server
|
||||
|
|
|
@ -203,6 +203,7 @@ class WideDeepModel(nn.Cell):
|
|||
self.dense_layer_1.dropout.dropout_do_mask.set_strategy(((1, get_group_size()),))
|
||||
self.dense_layer_1.dropout.dropout.set_strategy(((1, get_group_size()),))
|
||||
self.dense_layer_1.matmul.set_strategy(((1, get_group_size()), (get_group_size(), 1)))
|
||||
self.dense_layer_1.matmul.add_prim_attr("field_size", config.field_size)
|
||||
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim,
|
||||
slice_mode=nn.EmbeddingLookUpSplitMode.TABLE_COLUMN_SLICE)
|
||||
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1,
|
||||
|
@ -211,6 +212,10 @@ class WideDeepModel(nn.Cell):
|
|||
self.deep_reshape.add_prim_attr("skip_redistribution", True)
|
||||
self.reduce_sum.add_prim_attr("cross_batch", True)
|
||||
self.embedding_table = self.deep_embeddinglookup.embedding_table
|
||||
elif host_device_mix:
|
||||
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim)
|
||||
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1)
|
||||
self.embedding_table = self.deep_embeddinglookup.embedding_table
|
||||
elif parameter_server:
|
||||
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim)
|
||||
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1)
|
||||
|
|
|
@ -111,10 +111,11 @@ def train_and_eval(config):
|
|||
eval_callback = EvalCallBack(model, ds_eval, auc_metric, config, host_device_mix=host_device_mix)
|
||||
|
||||
callback = LossCallBack(config=config, per_print_times=20)
|
||||
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
|
||||
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*epochs,
|
||||
keep_checkpoint_max=5, integrated_save=False)
|
||||
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
|
||||
directory=config.ckpt_path, config=ckptconfig)
|
||||
context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_train.ckpt")
|
||||
context.set_auto_parallel_context(strategy_ckpt_save_file=config.stra_ckpt)
|
||||
callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback]
|
||||
if not host_device_mix:
|
||||
callback_list.append(ckpoint_cb)
|
||||
|
|
|
@ -30,6 +30,8 @@ def argparse_init():
|
|||
parser.add_argument("--deep_layer_dim", type=int, nargs='+', default=[1024, 512, 256, 128])
|
||||
parser.add_argument("--deep_layer_act", type=str, default='relu')
|
||||
parser.add_argument("--keep_prob", type=float, default=1.0)
|
||||
parser.add_argument("--stra_ckpt", type=str, default="./strategy_train.ckpt",
|
||||
help="The strategy checkpoint file.")
|
||||
|
||||
parser.add_argument("--output_path", type=str, default="./output/")
|
||||
parser.add_argument("--ckpt_path", type=str, default="./checkpoints/")
|
||||
|
@ -63,6 +65,7 @@ class WideDeepConfig():
|
|||
self.eval_file_name = "eval.log"
|
||||
self.loss_file_name = "loss.log"
|
||||
self.ckpt_path = "./checkpoints/"
|
||||
self.stra_ckpt = "./strategy_train.ckpt"
|
||||
|
||||
def argparse_init(self):
|
||||
"""
|
||||
|
@ -90,3 +93,4 @@ class WideDeepConfig():
|
|||
self.eval_file_name = args.eval_file_name
|
||||
self.loss_file_name = args.loss_file_name
|
||||
self.ckpt_path = args.ckpt_path
|
||||
self.stra_ckpt = args.stra_ckpt
|
||||
|
|
Loading…
Reference in New Issue