add return value: ckpt_not_load
This commit is contained in:
parent
4a8d98014d
commit
9632ee3059
|
@ -12,6 +12,7 @@ mindspore.load_param_into_net
|
|||
|
||||
返回:
|
||||
List,网络中没有被加载的参数。
|
||||
List,checkpoint文件中没有被加载的参数。
|
||||
|
||||
异常:
|
||||
- **TypeError** - 如果参数不是Cell,或者 `parameter_dict` 不是Parameter类型的字典。
|
||||
|
|
|
@ -958,7 +958,8 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|||
on the parameters of the same type, such as float32 to float16. Default: False.
|
||||
|
||||
Returns:
|
||||
List, the parameter name which are not loaded into the network.
|
||||
List, the parameter name in model which are not loaded into the network.
|
||||
List, the parameter name in checkpoint file which are not loaded into the network.
|
||||
|
||||
Raises:
|
||||
TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary.
|
||||
|
@ -994,10 +995,12 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|||
logger.info("Execute the process of loading parameters into net.")
|
||||
net.init_parameters_data()
|
||||
param_not_load = []
|
||||
ckpt_not_load = list(parameter_dict.keys())
|
||||
for _, param in net.parameters_and_names():
|
||||
if param.name in parameter_dict:
|
||||
new_param = copy.deepcopy(parameter_dict[param.name])
|
||||
_update_param(param, new_param, strict_load)
|
||||
ckpt_not_load.remove(param.name)
|
||||
else:
|
||||
param_not_load.append(param.name)
|
||||
|
||||
|
@ -1016,7 +1019,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|||
"when training and loading checkpoint.".format(len(param_not_load)))
|
||||
for param_name in param_not_load:
|
||||
logger.warning("{} is not loaded.".format(param_name))
|
||||
return param_not_load
|
||||
return param_not_load, ckpt_not_load
|
||||
|
||||
|
||||
def _load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load):
|
||||
|
|
Loading…
Reference in New Issue