!4482 add_field_in_strategy_ckpt

Merge pull request !4482 from yao_yf/add_field_in_strategy_ckpt
This commit is contained in:
mindspore-ci-bot 2020-08-17 10:11:40 +08:00 committed by Gitee
commit f384444487
3 changed files with 3 additions and 5 deletions

View File

@ -129,6 +129,7 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInf
param_split_shape->add_dim(dim_pair.first);
indices_offset->add_dim(dim_pair.second);
}
parallel_layouts->set_field(tensor_layout.get_field_size());
}
std::fstream output(save_file_, std::ios::out | std::ios::trunc | std::ios::binary);

View File

@ -53,6 +53,7 @@ message ParallelLayouts {
repeated TensorMap tensor_map = 2;
repeated ParamSplitShape param_split_shape = 3;
repeated IndicesOffset indices_offset = 4;
required int32 field = 5;
}
message ParallelLayoutItem {

View File

@ -161,13 +161,9 @@ class WideDeepModel(nn.Cell):
self.layer_dims = self.deep_layer_dims_list + [1]
self.all_dim_list = [self.deep_input_dims] + self.layer_dims
init_acts = [('Wide_w', [self.vocab_size, 1], self.emb_init),
('V_l2', [self.vocab_size, self.emb_dim], self.emb_init),
('Wide_b', [1], self.emb_init)]
init_acts = [('Wide_b', [1], self.emb_init)]
var_map = init_var_dict(self.init_args, init_acts)
self.wide_w = var_map["Wide_w"]
self.wide_b = var_map["Wide_b"]
self.embedding_table = var_map["V_l2"]
if parameter_server:
self.wide_w.set_param_ps()
self.embedding_table.set_param_ps()