forked from OSSInnovation/mindspore
!6699 add dataset size check
Merge pull request !6699 from wukesong/add-dataset-check
This commit is contained in:
commit
9a1cfe76c7
|
@ -77,5 +77,8 @@ if __name__ == "__main__":
|
|||
else:
|
||||
raise ValueError("Unsupport dataset.")
|
||||
|
||||
if ds_eval.get_dataset_size() == 0:
|
||||
raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")
|
||||
|
||||
result = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode)
|
||||
print("result : {}".format(result))
|
||||
|
|
|
@ -91,6 +91,9 @@ if __name__ == "__main__":
|
|||
else:
|
||||
raise ValueError("Unsupport dataset.")
|
||||
|
||||
if ds_train.get_dataset_size() == 0:
|
||||
raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")
|
||||
|
||||
network = AlexNet(cfg.num_classes)
|
||||
|
||||
loss_scale_manager = None
|
||||
|
|
|
@ -57,5 +57,8 @@ if __name__ == "__main__":
|
|||
ds_eval = create_dataset(os.path.join(args.data_path, "test"),
|
||||
cfg.batch_size,
|
||||
1)
|
||||
if ds_eval.get_dataset_size() == 0:
|
||||
raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")
|
||||
|
||||
acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode)
|
||||
print("============== {} ==============".format(acc))
|
||||
|
|
|
@ -50,6 +50,8 @@ if __name__ == "__main__":
|
|||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
ds_train = create_dataset(os.path.join(args.data_path, "train"),
|
||||
cfg.batch_size)
|
||||
if ds_train.get_dataset_size() == 0:
|
||||
raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")
|
||||
|
||||
network = LeNet5(cfg.num_classes)
|
||||
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
|
||||
|
|
Loading…
Reference in New Issue