!3907 modify ckpt func check parameter
Merge pull request !3907 from changzherui/mod_ckpt_func_param
This commit is contained in:
commit
3dcea81721
|
@ -108,13 +108,13 @@ class CheckpointConfig:
|
|||
not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
|
||||
raise ValueError("The input_param can't be all None or 0")
|
||||
|
||||
if save_checkpoint_steps:
|
||||
if save_checkpoint_steps is not None:
|
||||
save_checkpoint_steps = check_int_non_negative(save_checkpoint_steps)
|
||||
if save_checkpoint_seconds:
|
||||
if save_checkpoint_seconds is not None:
|
||||
save_checkpoint_seconds = check_int_non_negative(save_checkpoint_seconds)
|
||||
if keep_checkpoint_max:
|
||||
if keep_checkpoint_max is not None:
|
||||
keep_checkpoint_max = check_int_non_negative(keep_checkpoint_max)
|
||||
if keep_checkpoint_per_n_minutes:
|
||||
if keep_checkpoint_per_n_minutes is not None:
|
||||
keep_checkpoint_per_n_minutes = check_int_non_negative(keep_checkpoint_per_n_minutes)
|
||||
|
||||
self._save_checkpoint_steps = save_checkpoint_steps
|
||||
|
|
|
@ -258,7 +258,7 @@ def load_checkpoint(ckpt_file_name, net=None):
|
|||
logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
|
||||
raise RuntimeError(e.__str__())
|
||||
|
||||
if net:
|
||||
if net is not None:
|
||||
load_param_into_net(net, parameter_dict)
|
||||
|
||||
return parameter_dict
|
||||
|
|
Loading…
Reference in New Issue