fix checkpoint evaliaction.

This commit is contained in:
chenzomi 2020-06-25 09:43:27 +08:00
parent 3b632eac46
commit bed6332688
1 changed files with 4 additions and 3 deletions

View File

@ -187,6 +187,7 @@ def load_checkpoint(ckpt_file_name, model_type="normal", net=None):
raise ValueError(e.__str__())
parameter_dict = {}
if checkpoint_list.model_type:
if model_type != checkpoint_list.model_type:
raise KeyError("Checkpoint file model type({}) is not equal to input model type({}).".format(
checkpoint_list.model_type, model_type))