!22715 【轻量级 PR】:update model_zoo/research/cv/SE_ResNeXt50/eval.py.

Merge pull request !22715 from yexijoe/N/A
This commit is contained in:
i-robot 2021-09-09 07:33:37 +00:00 committed by Gitee
commit df67984875
1 changed files with 6 additions and 6 deletions

View File

@ -62,14 +62,8 @@ if __name__ == '__main__':
if args_opt.dataset_name == "imagenet":
cfg = imagenet_cfg
dataset = create_dataset_imagenet(cfg.val_data_path, 1, False)
if not cfg.use_label_smooth:
cfg.label_smooth_factor = 0.0
loss = CrossEntropySmooth(sparse=True, reduction="mean",
smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)
net = senets.se_resnext50_32x4d(cfg.num_classes)
model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
else:
raise ValueError("dataset is not support.")
@ -78,6 +72,12 @@ if __name__ == '__main__':
if device_target == "Ascend":
context.set_context(device_id=cfg.device_id)
dataset = create_dataset_imagenet(cfg.val_data_path, 1, False)
loss = CrossEntropySmooth(sparse=True, reduction="mean",
smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)
net = senets.se_resnext50_32x4d(cfg.num_classes)
model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
file_list = os.listdir(args_opt.checkpoint_path)
for filename in file_list:
de_path = os.path.join(args_opt.checkpoint_path, filename)