forked from mindspore-Ecosystem/mindspore
!4482 add_field_in_strategy_ckpt
Merge pull request !4482 from yao_yf/add_field_in_strategy_ckpt
This commit is contained in:
commit
f384444487
|
@ -129,6 +129,7 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInf
|
|||
param_split_shape->add_dim(dim_pair.first);
|
||||
indices_offset->add_dim(dim_pair.second);
|
||||
}
|
||||
parallel_layouts->set_field(tensor_layout.get_field_size());
|
||||
}
|
||||
|
||||
std::fstream output(save_file_, std::ios::out | std::ios::trunc | std::ios::binary);
|
||||
|
|
|
@ -53,6 +53,7 @@ message ParallelLayouts {
|
|||
repeated TensorMap tensor_map = 2;
|
||||
repeated ParamSplitShape param_split_shape = 3;
|
||||
repeated IndicesOffset indices_offset = 4;
|
||||
required int32 field = 5;
|
||||
}
|
||||
|
||||
message ParallelLayoutItem {
|
||||
|
|
|
@ -161,13 +161,9 @@ class WideDeepModel(nn.Cell):
|
|||
self.layer_dims = self.deep_layer_dims_list + [1]
|
||||
self.all_dim_list = [self.deep_input_dims] + self.layer_dims
|
||||
|
||||
init_acts = [('Wide_w', [self.vocab_size, 1], self.emb_init),
|
||||
('V_l2', [self.vocab_size, self.emb_dim], self.emb_init),
|
||||
('Wide_b', [1], self.emb_init)]
|
||||
init_acts = [('Wide_b', [1], self.emb_init)]
|
||||
var_map = init_var_dict(self.init_args, init_acts)
|
||||
self.wide_w = var_map["Wide_w"]
|
||||
self.wide_b = var_map["Wide_b"]
|
||||
self.embedding_table = var_map["V_l2"]
|
||||
if parameter_server:
|
||||
self.wide_w.set_param_ps()
|
||||
self.embedding_table.set_param_ps()
|
||||
|
|
Loading…
Reference in New Issue