!6699 add dataset size check

Merge pull request !6699 from wukesong/add-dataset-check
This commit is contained in:
mindspore-ci-bot 2020-09-23 11:39:51 +08:00 committed by Gitee
commit 9a1cfe76c7
4 changed files with 11 additions and 0 deletions

View File

@ -77,5 +77,8 @@ if __name__ == "__main__":
else:
raise ValueError("Unsupport dataset.")
if ds_eval.get_dataset_size() == 0:
raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")
result = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode)
print("result : {}".format(result))

View File

@ -91,6 +91,9 @@ if __name__ == "__main__":
else:
raise ValueError("Unsupport dataset.")
if ds_train.get_dataset_size() == 0:
raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")
network = AlexNet(cfg.num_classes)
loss_scale_manager = None

View File

@ -57,5 +57,8 @@ if __name__ == "__main__":
ds_eval = create_dataset(os.path.join(args.data_path, "test"),
cfg.batch_size,
1)
if ds_eval.get_dataset_size() == 0:
raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")
acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode)
print("============== {} ==============".format(acc))

View File

@ -50,6 +50,8 @@ if __name__ == "__main__":
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
ds_train = create_dataset(os.path.join(args.data_path, "train"),
cfg.batch_size)
if ds_train.get_dataset_size() == 0:
raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")
network = LeNet5(cfg.num_classes)
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")