forked from mindspore-Ecosystem/mindspore
!49370 ckpt参数更新接口load_param_into_net新增返回值:ckpt_not_load
Merge pull request !49370 from GuoZhibin/add_return_value_ckpt_not_load
This commit is contained in:
commit
53714e4a21
|
@ -12,6 +12,7 @@ mindspore.load_param_into_net
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
List,网络中没有被加载的参数。
|
List,网络中没有被加载的参数。
|
||||||
|
List,checkpoint文件中没有被加载的参数。
|
||||||
|
|
||||||
异常:
|
异常:
|
||||||
- **TypeError** - 如果参数不是Cell,或者 `parameter_dict` 不是Parameter类型的字典。
|
- **TypeError** - 如果参数不是Cell,或者 `parameter_dict` 不是Parameter类型的字典。
|
||||||
|
|
|
@ -963,7 +963,8 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
||||||
on the parameters of the same type, such as float32 to float16. Default: False.
|
on the parameters of the same type, such as float32 to float16. Default: False.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List, the parameter name which are not loaded into the network.
|
List, the parameter name in model which are not loaded into the network.
|
||||||
|
List, the parameter name in checkpoint file which are not loaded into the network.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary.
|
TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary.
|
||||||
|
@ -999,10 +1000,12 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
||||||
logger.info("Execute the process of loading parameters into net.")
|
logger.info("Execute the process of loading parameters into net.")
|
||||||
net.init_parameters_data()
|
net.init_parameters_data()
|
||||||
param_not_load = []
|
param_not_load = []
|
||||||
|
ckpt_not_load = list(parameter_dict.keys())
|
||||||
for _, param in net.parameters_and_names():
|
for _, param in net.parameters_and_names():
|
||||||
if param.name in parameter_dict:
|
if param.name in parameter_dict:
|
||||||
new_param = copy.deepcopy(parameter_dict[param.name])
|
new_param = copy.deepcopy(parameter_dict[param.name])
|
||||||
_update_param(param, new_param, strict_load)
|
_update_param(param, new_param, strict_load)
|
||||||
|
ckpt_not_load.remove(param.name)
|
||||||
else:
|
else:
|
||||||
param_not_load.append(param.name)
|
param_not_load.append(param.name)
|
||||||
|
|
||||||
|
@ -1021,7 +1024,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
||||||
"when training and loading checkpoint.".format(len(param_not_load)))
|
"when training and loading checkpoint.".format(len(param_not_load)))
|
||||||
for param_name in param_not_load:
|
for param_name in param_not_load:
|
||||||
logger.warning("{} is not loaded.".format(param_name))
|
logger.warning("{} is not loaded.".format(param_name))
|
||||||
return param_not_load
|
return param_not_load, ckpt_not_load
|
||||||
|
|
||||||
|
|
||||||
def _load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load):
|
def _load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load):
|
||||||
|
|
Loading…
Reference in New Issue