fix shufflenetv1 n_class

This commit is contained in:
chenhaozhe 2021-08-02 19:40:07 +08:00
parent 7f71a99993
commit 20275fa1b9
3 changed files with 3 additions and 3 deletions

View File

@ -39,7 +39,7 @@ def test():
# step_size = dataset.get_dataset_size()
# define net
net = shufflenetv1(model_size=config.model_size)
net = shufflenetv1(model_size=config.model_size, n_class=config.num_classes)
# load checkpoint
param_dict = load_checkpoint(config.ckpt_path)

View File

@ -38,7 +38,7 @@ if config.device_target == "Ascend":
@moxing_wrapper(pre_process=modelarts_pre_process)
def model_export():
net = ShuffleNetV1(model_size=config.model_size)
net = ShuffleNetV1(model_size=config.model_size, n_class=config.num_classes)
param_dict = load_checkpoint(config.ckpt_path)
load_param_into_net(net, param_dict)

View File

@ -58,7 +58,7 @@ def train():
context.set_context(device_id=config.device_id)
# define network
net = ShuffleNetV1(model_size=config.model_size)
net = ShuffleNetV1(model_size=config.model_size, n_class=config.num_classes)
# define loss
loss = CrossEntropySmooth(sparse=True, reduction="mean", smooth_factor=config.label_smooth_factor,