fix shufflenetv1 n_class
This commit is contained in:
parent
7f71a99993
commit
20275fa1b9
|
@ -39,7 +39,7 @@ def test():
|
|||
# step_size = dataset.get_dataset_size()
|
||||
|
||||
# define net
|
||||
net = shufflenetv1(model_size=config.model_size)
|
||||
net = shufflenetv1(model_size=config.model_size, n_class=config.num_classes)
|
||||
|
||||
# load checkpoint
|
||||
param_dict = load_checkpoint(config.ckpt_path)
|
||||
|
|
|
@ -38,7 +38,7 @@ if config.device_target == "Ascend":
|
|||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def model_export():
|
||||
net = ShuffleNetV1(model_size=config.model_size)
|
||||
net = ShuffleNetV1(model_size=config.model_size, n_class=config.num_classes)
|
||||
|
||||
param_dict = load_checkpoint(config.ckpt_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
|
|
@ -58,7 +58,7 @@ def train():
|
|||
context.set_context(device_id=config.device_id)
|
||||
|
||||
# define network
|
||||
net = ShuffleNetV1(model_size=config.model_size)
|
||||
net = ShuffleNetV1(model_size=config.model_size, n_class=config.num_classes)
|
||||
|
||||
# define loss
|
||||
loss = CrossEntropySmooth(sparse=True, reduction="mean", smooth_factor=config.label_smooth_factor,
|
||||
|
|
Loading…
Reference in New Issue