forked from mindspore-Ecosystem/mindspore
!22715 【轻量级 PR】:update model_zoo/research/cv/SE_ResNeXt50/eval.py.
Merge pull request !22715 from yexijoe/N/A
This commit is contained in:
commit
df67984875
|
@ -62,14 +62,8 @@ if __name__ == '__main__':
|
|||
|
||||
if args_opt.dataset_name == "imagenet":
|
||||
cfg = imagenet_cfg
|
||||
dataset = create_dataset_imagenet(cfg.val_data_path, 1, False)
|
||||
if not cfg.use_label_smooth:
|
||||
cfg.label_smooth_factor = 0.0
|
||||
loss = CrossEntropySmooth(sparse=True, reduction="mean",
|
||||
smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)
|
||||
net = senets.se_resnext50_32x4d(cfg.num_classes)
|
||||
model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
|
||||
|
||||
else:
|
||||
raise ValueError("dataset is not support.")
|
||||
|
||||
|
@ -78,6 +72,12 @@ if __name__ == '__main__':
|
|||
if device_target == "Ascend":
|
||||
context.set_context(device_id=cfg.device_id)
|
||||
|
||||
dataset = create_dataset_imagenet(cfg.val_data_path, 1, False)
|
||||
loss = CrossEntropySmooth(sparse=True, reduction="mean",
|
||||
smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)
|
||||
net = senets.se_resnext50_32x4d(cfg.num_classes)
|
||||
model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
|
||||
|
||||
file_list = os.listdir(args_opt.checkpoint_path)
|
||||
for filename in file_list:
|
||||
de_path = os.path.join(args_opt.checkpoint_path, filename)
|
||||
|
|
Loading…
Reference in New Issue