From 119100618f2d1f6cf0f84265c475f387f5247b20 Mon Sep 17 00:00:00 2001
From: liu-yongqi-63
Date: Thu, 2 Mar 2023 15:31:38 +0800
Subject: [PATCH] Add test cases for the len() method and for the data types
 supported by the BroadcastTo operator on the GPU

---
 .../kernel/arrays/broadcast_to_gpu_kernel.cc   | 118 ++++++++++++++++++
 .../dataset/test_datasets_get_dataset_size.py  |  96 ++++++++++++--
 2 files changed, 203 insertions(+), 11 deletions(-)
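With these kernels registered, BroadcastTo on the GPU should accept the newly
added signed/unsigned integer, bool, and complex dtypes end to end, with the
target shape supplied either as an int32/int64 tensor or as a tuple. A minimal
usage sketch (illustrative only, not part of the commit; assumes a GPU device
target is available and uses the functional ops.broadcast_to API):

    # Sketch: exercise two of the newly registered dtypes on the GPU.
    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, ops

    ms.set_context(device_target="GPU")

    x_u8 = Tensor(np.ones((1, 3), dtype=np.uint8))
    y_u8 = ops.broadcast_to(x_u8, (2, 3))    # uint8: (1, 3) -> (2, 3)

    x_c64 = Tensor(np.ones((1, 3), dtype=np.complex64))
    y_c64 = ops.broadcast_to(x_c64, (2, 3))  # complex64: (1, 3) -> (2, 3)

    assert y_u8.shape == (2, 3) and y_c64.shape == (2, 3)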
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/broadcast_to_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/broadcast_to_gpu_kernel.cc
index 408f865e678..b410abaadc2 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/broadcast_to_gpu_kernel.cc
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/broadcast_to_gpu_kernel.cc
@@ -138,24 +138,62 @@ std::vector<std::pair<KernelAttr, BroadcastToGpuKernelMod::BroadcastToLaunchFunc>> BroadcastToGpuKernelMod::func_list_ = {
    &BroadcastToGpuKernelMod::LaunchKernel<float>},
   {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
    &BroadcastToGpuKernelMod::LaunchKernel<half>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
+   &BroadcastToGpuKernelMod::LaunchKernel<int8_t>},
   {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
    &BroadcastToGpuKernelMod::LaunchKernel<int16_t>},
   {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
    &BroadcastToGpuKernelMod::LaunchKernel<int32_t>},
   {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
    &BroadcastToGpuKernelMod::LaunchKernel<int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
+   &BroadcastToGpuKernelMod::LaunchKernel<uint8_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
+   &BroadcastToGpuKernelMod::LaunchKernel<uint16_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
+   &BroadcastToGpuKernelMod::LaunchKernel<uint32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
+   &BroadcastToGpuKernelMod::LaunchKernel<uint64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
+   &BroadcastToGpuKernelMod::LaunchKernel<bool>},
+  {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex64),
+   &BroadcastToGpuKernelMod::LaunchKernel<Complex<float>>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeComplex128)
+     .AddInputAttr(kNumberTypeInt64)
+     .AddOutputAttr(kNumberTypeComplex128),
+   &BroadcastToGpuKernelMod::LaunchKernel<Complex<double>>},
   {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
    &BroadcastToGpuKernelMod::LaunchKernel<double>},
   {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
    &BroadcastToGpuKernelMod::LaunchKernel<float>},
   {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
    &BroadcastToGpuKernelMod::LaunchKernel<half>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
+   &BroadcastToGpuKernelMod::LaunchKernel<int8_t>},
   {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
    &BroadcastToGpuKernelMod::LaunchKernel<int16_t>},
   {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
    &BroadcastToGpuKernelMod::LaunchKernel<int32_t>},
   {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
    &BroadcastToGpuKernelMod::LaunchKernel<int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
+   &BroadcastToGpuKernelMod::LaunchKernel<uint8_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
+   &BroadcastToGpuKernelMod::LaunchKernel<uint16_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
+   &BroadcastToGpuKernelMod::LaunchKernel<uint32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
+   &BroadcastToGpuKernelMod::LaunchKernel<uint64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
+   &BroadcastToGpuKernelMod::LaunchKernel<bool>},
+  {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex64),
+   &BroadcastToGpuKernelMod::LaunchKernel<Complex<float>>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeComplex128)
+     .AddInputAttr(kNumberTypeInt32)
+     .AddOutputAttr(kNumberTypeComplex128),
+   &BroadcastToGpuKernelMod::LaunchKernel<Complex<double>>},
   {KernelAttr()
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
@@ -171,6 +209,11 @@ std::vector<std::pair<KernelAttr, BroadcastToGpuKernelMod::BroadcastToLaunchFunc>> BroadcastToGpuKernelMod::func_list_ = {
     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeFloat16),
    &BroadcastToGpuKernelMod::LaunchKernel<half>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeInt8)
+     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
+     .AddOutputAttr(kNumberTypeInt8),
+   &BroadcastToGpuKernelMod::LaunchKernel<int8_t>},
   {KernelAttr()
     .AddInputAttr(kNumberTypeInt16)
     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
@@ -186,6 +229,41 @@ std::vector<std::pair<KernelAttr, BroadcastToGpuKernelMod::BroadcastToLaunchFunc>> BroadcastToGpuKernelMod::func_list_ = {
     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeInt64),
    &BroadcastToGpuKernelMod::LaunchKernel<int64_t>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeUInt8)
+     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
+     .AddOutputAttr(kNumberTypeUInt8),
+   &BroadcastToGpuKernelMod::LaunchKernel<uint8_t>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeUInt16)
+     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
+     .AddOutputAttr(kNumberTypeUInt16),
+   &BroadcastToGpuKernelMod::LaunchKernel<uint16_t>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeUInt32)
+     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
+     .AddOutputAttr(kNumberTypeUInt32),
+   &BroadcastToGpuKernelMod::LaunchKernel<uint32_t>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeUInt64)
+     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
+     .AddOutputAttr(kNumberTypeUInt64),
+   &BroadcastToGpuKernelMod::LaunchKernel<uint64_t>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeBool)
+     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
+     .AddOutputAttr(kNumberTypeBool),
+   &BroadcastToGpuKernelMod::LaunchKernel<bool>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeComplex64)
+     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
+     .AddOutputAttr(kNumberTypeComplex64),
+   &BroadcastToGpuKernelMod::LaunchKernel<Complex<float>>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeComplex128)
+     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
+     .AddOutputAttr(kNumberTypeComplex128),
+   &BroadcastToGpuKernelMod::LaunchKernel<Complex<double>>},
   {KernelAttr()
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
@@ -201,6 +279,11 @@ std::vector<std::pair<KernelAttr, BroadcastToGpuKernelMod::BroadcastToLaunchFunc>> BroadcastToGpuKernelMod::func_list_ = {
     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeFloat16),
    &BroadcastToGpuKernelMod::LaunchKernel<half>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeInt8)
+     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
+     .AddOutputAttr(kNumberTypeInt8),
+   &BroadcastToGpuKernelMod::LaunchKernel<int8_t>},
   {KernelAttr()
     .AddInputAttr(kNumberTypeInt16)
     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
@@ -216,6 +299,41 @@ std::vector<std::pair<KernelAttr, BroadcastToGpuKernelMod::BroadcastToLaunchFunc>> BroadcastToGpuKernelMod::func_list_ = {
     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeInt64),
    &BroadcastToGpuKernelMod::LaunchKernel<int64_t>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeUInt8)
+     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
+     .AddOutputAttr(kNumberTypeUInt8),
+   &BroadcastToGpuKernelMod::LaunchKernel<uint8_t>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeUInt16)
+     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
+     .AddOutputAttr(kNumberTypeUInt16),
+   &BroadcastToGpuKernelMod::LaunchKernel<uint16_t>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeUInt32)
+     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
+     .AddOutputAttr(kNumberTypeUInt32),
+   &BroadcastToGpuKernelMod::LaunchKernel<uint32_t>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeUInt64)
+     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
+     .AddOutputAttr(kNumberTypeUInt64),
+   &BroadcastToGpuKernelMod::LaunchKernel<uint64_t>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeBool)
+     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
+     .AddOutputAttr(kNumberTypeBool),
+   &BroadcastToGpuKernelMod::LaunchKernel<bool>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeComplex64)
+     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
+     .AddOutputAttr(kNumberTypeComplex64),
+   &BroadcastToGpuKernelMod::LaunchKernel<Complex<float>>},
+  {KernelAttr()
+     .AddInputAttr(kNumberTypeComplex128)
+     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
+     .AddOutputAttr(kNumberTypeComplex128),
+   &BroadcastToGpuKernelMod::LaunchKernel<Complex<double>>},
 };
 
 std::vector<KernelAttr> BroadcastToGpuKernelMod::GetOpSupport() {
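The dataset test changes below pair each get_dataset_size() assertion with an
equivalent len() assertion. This relies on Dataset.__len__ delegating to
get_dataset_size(), so the two must always agree, including on sharded,
batched, repeated, and concatenated pipelines. A minimal sketch of the
equivalence being exercised (illustrative only; MNIST_DATA_DIR as defined in
the test module):

    import mindspore.dataset as ds

    dataset = ds.MnistDataset(MNIST_DATA_DIR)
    # len() routes through Dataset.__len__, which should report the same
    # row count as get_dataset_size().
    assert len(dataset) == dataset.get_dataset_size()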
diff --git a/tests/ut/python/dataset/test_datasets_get_dataset_size.py b/tests/ut/python/dataset/test_datasets_get_dataset_size.py
index 74536050e13..a4c0d003892 100644
--- a/tests/ut/python/dataset/test_datasets_get_dataset_size.py
+++ b/tests/ut/python/dataset/test_datasets_get_dataset_size.py
@@ -44,15 +44,19 @@ def test_imagenet_rawdata_dataset_size():
     """
     ds_total = ds.ImageFolderDataset(IMAGENET_RAWDATA_DIR)
     assert ds_total.get_dataset_size() == 6
+    assert len(ds_total) == 6
 
     ds_shard_1_0 = ds.ImageFolderDataset(IMAGENET_RAWDATA_DIR, num_shards=1, shard_id=0)
     assert ds_shard_1_0.get_dataset_size() == 6
+    assert len(ds_shard_1_0) == 6
 
     ds_shard_2_0 = ds.ImageFolderDataset(IMAGENET_RAWDATA_DIR, num_shards=2, shard_id=0)
     assert ds_shard_2_0.get_dataset_size() == 3
+    assert len(ds_shard_2_0) == 3
 
     ds_shard_3_0 = ds.ImageFolderDataset(IMAGENET_RAWDATA_DIR, num_shards=3, shard_id=0)
     assert ds_shard_3_0.get_dataset_size() == 2
+    assert len(ds_shard_3_0) == 2
 
 
 def test_imagenet_tf_file_dataset_size():
@@ -63,20 +67,25 @@ def test_imagenet_tf_file_dataset_size():
     """
     ds_total = ds.TFRecordDataset(IMAGENET_TFFILE_DIR)
     assert ds_total.get_dataset_size() == 12
+    assert len(ds_total) == 12
 
     ds_shard_1_0 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, num_shards=1, shard_id=0, shard_equal_rows=True)
     assert ds_shard_1_0.get_dataset_size() == 12
+    assert len(ds_shard_1_0) == 12
 
     ds_shard_2_0 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, num_shards=2, shard_id=0, shard_equal_rows=True)
     assert ds_shard_2_0.get_dataset_size() == 6
+    assert len(ds_shard_2_0) == 6
 
     ds_shard_3_0 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, num_shards=3, shard_id=0, shard_equal_rows=True)
     assert ds_shard_3_0.get_dataset_size() == 4
+    assert len(ds_shard_3_0) == 4
 
     count = 0
     for _ in ds_shard_3_0.create_dict_iterator(num_epochs=1):
         count += 1
     assert ds_shard_3_0.get_dataset_size() == count
+    assert len(ds_shard_3_0) == count
 
     # shard_equal_rows is set to False therefore, get_dataset_size must return count
     ds_shard_4_0 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, num_shards=4, shard_id=0)
@@ -84,6 +93,7 @@ def test_imagenet_tf_file_dataset_size():
     count = 0
     for _ in ds_shard_4_0.create_dict_iterator(num_epochs=1):
         count += 1
     assert ds_shard_4_0.get_dataset_size() == count
+    assert len(ds_shard_4_0) == count
 
 
 def test_mnist_dataset_size():
@@ -94,23 +104,36 @@ def test_mnist_dataset_size():
     """
     ds_total = ds.MnistDataset(MNIST_DATA_DIR)
     assert ds_total.get_dataset_size() == 10000
+    assert len(ds_total) == 10000
 
     # test get dataset_size with the usage arg
-    test_size = ds.MnistDataset(MNIST_DATA_DIR, usage="test").get_dataset_size()
+    test_dataset = ds.MnistDataset(MNIST_DATA_DIR, usage="test")
+    train_dataset = ds.MnistDataset(MNIST_DATA_DIR, usage="train")
+    all_dataset = ds.MnistDataset(MNIST_DATA_DIR, usage="all")
+
+    test_size = test_dataset.get_dataset_size()
     assert test_size == 10000
-    train_size = ds.MnistDataset(MNIST_DATA_DIR, usage="train").get_dataset_size()
+    assert len(test_dataset) == 10000
+
+    train_size = train_dataset.get_dataset_size()
     assert train_size == 0
-    all_size = ds.MnistDataset(MNIST_DATA_DIR, usage="all").get_dataset_size()
+    assert len(train_dataset) == train_size
+
+    all_size = all_dataset.get_dataset_size()
     assert all_size == 10000
+    assert len(all_dataset) == 10000
 
     ds_shard_1_0 = ds.MnistDataset(MNIST_DATA_DIR, num_shards=1, shard_id=0)
     assert ds_shard_1_0.get_dataset_size() == 10000
+    assert len(ds_shard_1_0) == 10000
 
     ds_shard_2_0 = ds.MnistDataset(MNIST_DATA_DIR, num_shards=2, shard_id=0)
     assert ds_shard_2_0.get_dataset_size() == 5000
+    assert len(ds_shard_2_0) == 5000
 
     ds_shard_3_0 = ds.MnistDataset(MNIST_DATA_DIR, num_shards=3, shard_id=0)
     assert ds_shard_3_0.get_dataset_size() == 3334
+    assert len(ds_shard_3_0) == 3334
 
 
 def test_mind_dataset_size():
@@ -121,9 +144,11 @@ def test_mind_dataset_size():
     """
     dataset = ds.MindDataset(MIND_CV_FILE_NAME + "0")
     assert dataset.get_dataset_size() == 20
+    assert len(dataset) == 20
 
     dataset_shard_2_0 = ds.MindDataset(MIND_CV_FILE_NAME + "0", num_shards=2, shard_id=0)
     assert dataset_shard_2_0.get_dataset_size() == 10
+    assert len(dataset_shard_2_0) == 10
 
 
 def test_manifest_dataset_size():
@@ -134,15 +159,19 @@ def test_manifest_dataset_size():
     """
     ds_total = ds.ManifestDataset(MANIFEST_DATA_FILE)
     assert ds_total.get_dataset_size() == 4
+    assert len(ds_total) == 4
 
     ds_shard_1_0 = ds.ManifestDataset(MANIFEST_DATA_FILE, num_shards=1, shard_id=0)
     assert ds_shard_1_0.get_dataset_size() == 4
+    assert len(ds_shard_1_0) == 4
 
     ds_shard_2_0 = ds.ManifestDataset(MANIFEST_DATA_FILE, num_shards=2, shard_id=0)
     assert ds_shard_2_0.get_dataset_size() == 2
+    assert len(ds_shard_2_0) == 2
 
     ds_shard_3_0 = ds.ManifestDataset(MANIFEST_DATA_FILE, num_shards=3, shard_id=0)
     assert ds_shard_3_0.get_dataset_size() == 2
+    assert len(ds_shard_3_0) == 2
 
 
 def test_cifar10_dataset_size():
@@ -153,27 +182,40 @@ def test_cifar10_dataset_size():
     """
     ds_total = ds.Cifar10Dataset(CIFAR10_DATA_DIR)
     assert ds_total.get_dataset_size() == 10000
+    assert len(ds_total) == 10000
 
     # test get_dataset_size with usage flag
-    train_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="train").get_dataset_size()
-    assert train_size == 0
-    train_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="train").get_dataset_size()
-    assert train_size == 10000
+    train_cifar10dataset = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="train")
+    train_cifar100dataset = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="train")
+    all_cifar10dataset = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="all")
 
-    all_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="all").get_dataset_size()
+    train_size = train_cifar100dataset.get_dataset_size()
+    assert train_size == 0
+    assert len(train_cifar100dataset) == train_size
+
+    train_size = train_cifar10dataset.get_dataset_size()
+    assert train_size == 10000
+    assert len(train_cifar10dataset) == 10000
+
+    all_size = all_cifar10dataset.get_dataset_size()
     assert all_size == 10000
+    assert len(all_cifar10dataset) == 10000
 
     ds_shard_1_0 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_shards=1, shard_id=0)
     assert ds_shard_1_0.get_dataset_size() == 10000
+    assert len(ds_shard_1_0) == 10000
 
     ds_shard_2_0 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_shards=2, shard_id=0)
     assert ds_shard_2_0.get_dataset_size() == 5000
+    assert len(ds_shard_2_0) == 5000
 
     ds_shard_3_0 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_shards=3, shard_id=0)
     assert ds_shard_3_0.get_dataset_size() == 3334
+    assert len(ds_shard_3_0) == 3334
 
     ds_shard_7_0 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_shards=7, shard_id=0)
     assert ds_shard_7_0.get_dataset_size() == 1429
+    assert len(ds_shard_7_0) == 1429
 
 
 def test_cifar100_dataset_size():
@@ -184,21 +226,31 @@ def test_cifar100_dataset_size():
     """
     ds_total = ds.Cifar100Dataset(CIFAR100_DATA_DIR)
     assert ds_total.get_dataset_size() == 10000
+    assert len(ds_total) == 10000
 
     # test get_dataset_size with usage flag
-    test_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="test").get_dataset_size()
+    test_cifar100dataset = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="test")
+    all_cifar100dataset = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="all")
+
+    test_size = test_cifar100dataset.get_dataset_size()
     assert test_size == 10000
-    all_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="all").get_dataset_size()
+    assert len(test_cifar100dataset) == 10000
+
+    all_size = all_cifar100dataset.get_dataset_size()
     assert all_size == 10000
+    assert len(all_cifar100dataset) == 10000
 
     ds_shard_1_0 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_shards=1, shard_id=0)
     assert ds_shard_1_0.get_dataset_size() == 10000
+    assert len(ds_shard_1_0) == 10000
 
     ds_shard_2_0 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_shards=2, shard_id=0)
     assert ds_shard_2_0.get_dataset_size() == 5000
+    assert len(ds_shard_2_0) == 5000
 
     ds_shard_3_0 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_shards=3, shard_id=0)
     assert ds_shard_3_0.get_dataset_size() == 3334
+    assert len(ds_shard_3_0) == 3334
 
 
 def test_voc_dataset_size():
@@ -209,10 +261,12 @@ def test_voc_dataset_size():
     """
     dataset = ds.VOCDataset(VOC_DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True)
     assert dataset.get_dataset_size() == 10
+    assert len(dataset) == 10
 
     dataset_shard_2_0 = ds.VOCDataset(VOC_DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True,
                                       num_shards=2, shard_id=0)
     assert dataset_shard_2_0.get_dataset_size() == 5
+    assert len(dataset_shard_2_0) == 5
 
 
 def test_coco_dataset_size():
@@ -224,10 +278,12 @@ def test_coco_dataset_size():
     dataset = ds.CocoDataset(COCO_DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection",
                              decode=True, shuffle=False)
     assert dataset.get_dataset_size() == 6
+    assert len(dataset) == 6
 
     dataset_shard_2_0 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection",
                                        decode=True, shuffle=False, num_shards=2, shard_id=0)
     assert dataset_shard_2_0.get_dataset_size() == 3
+    assert len(dataset_shard_2_0) == 3
 
 
 def test_celeba_dataset_size():
@@ -238,9 +294,11 @@ def test_celeba_dataset_size():
     """
     dataset = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True)
     assert dataset.get_dataset_size() == 4
+    assert len(dataset) == 4
 
     dataset_shard_2_0 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True, num_shards=2, shard_id=0)
     assert dataset_shard_2_0.get_dataset_size() == 2
+    assert len(dataset_shard_2_0) == 2
 
 
 def test_clue_dataset_size():
@@ -251,9 +309,11 @@ def test_clue_dataset_size():
     """
     dataset = ds.CLUEDataset(CLUE_FILE, task='AFQMC', usage='train', shuffle=False)
     assert dataset.get_dataset_size() == 3
+    assert len(dataset) == 3
 
     dataset_shard_2_0 = ds.CLUEDataset(CLUE_FILE, task='AFQMC', usage='train', shuffle=False, num_shards=2, shard_id=0)
     assert dataset_shard_2_0.get_dataset_size() == 2
+    assert len(dataset_shard_2_0) == 2
 
 
 def test_csv_dataset_size():
@@ -265,10 +325,12 @@ def test_csv_dataset_size():
     dataset = ds.CSVDataset(CSV_FILE, column_defaults=["0", 0, 0.0, "0"], column_names=['1', '2', '3', '4'],
                             shuffle=False)
     assert dataset.get_dataset_size() == 3
+    assert len(dataset) == 3
 
     dataset_shard_2_0 = ds.CSVDataset(CSV_FILE, column_defaults=["0", 0, 0.0, "0"], column_names=['1', '2', '3', '4'],
                                       shuffle=False, num_shards=2, shard_id=0)
     assert dataset_shard_2_0.get_dataset_size() == 2
+    assert len(dataset_shard_2_0) == 2
 
 
 def test_text_file_dataset_size():
@@ -279,9 +341,11 @@ def test_text_file_dataset_size():
     """
     dataset = ds.TextFileDataset(TEXT_DATA_FILE)
     assert dataset.get_dataset_size() == 3
+    assert len(dataset) == 3
 
     dataset_shard_2_0 = ds.TextFileDataset(TEXT_DATA_FILE, num_shards=2, shard_id=0)
     assert dataset_shard_2_0.get_dataset_size() == 2
+    assert len(dataset_shard_2_0) == 2
 
 
 def test_padded_dataset_size():
@@ -292,6 +356,7 @@ def test_padded_dataset_size():
     """
     dataset = ds.PaddedDataset([{"data": [1, 2, 3]}, {"data": [1, 0, 1]}])
     assert dataset.get_dataset_size() == 2
+    assert len(dataset) == 2
 
 
 def test_pipeline_get_dataset_size():
@@ -302,25 +367,32 @@ def test_pipeline_get_dataset_size():
     """
     dataset = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, SCHEMA_FILE, columns_list=["image"], shuffle=False)
     assert dataset.get_dataset_size() == 12
+    assert len(dataset) == 12
 
     dataset = dataset.shuffle(buffer_size=3)
    assert dataset.get_dataset_size() == 12
+    assert len(dataset) == 12
 
     decode_op = vision.Decode()
     resize_op = vision.RandomResize(10)
     dataset = dataset.map([decode_op, resize_op], input_columns=["image"])
     assert dataset.get_dataset_size() == 12
+    assert len(dataset) == 12
 
     dataset = dataset.batch(batch_size=3)
     assert dataset.get_dataset_size() == 4
+    assert len(dataset) == 4
 
     dataset = dataset.repeat(count=2)
     assert dataset.get_dataset_size() == 8
+    assert len(dataset) == 8
 
     tf1 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, shuffle=True)
     tf2 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, shuffle=True)
 
-    assert tf2.concat(tf1).get_dataset_size() == 24
+    tf3 = tf2.concat(tf1)
+    assert tf3.get_dataset_size() == 24
+    assert len(tf3) == 24
 
 
 def test_distributed_get_dataset_size():
@@ -332,6 +404,7 @@ def test_distributed_get_dataset_size():
     # Test get dataset size when num_samples is less than num_per_shard (10000/4 = 2500)
     dataset1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=2000, num_shards=4, shard_id=0)
     assert dataset1.get_dataset_size() == 2000
+    assert len(dataset1) == 2000
 
     count1 = 0
     for _ in dataset1.create_dict_iterator(num_epochs=1):
@@ -341,6 +414,7 @@ def test_distributed_get_dataset_size():
     # Test get dataset size when num_samples is more than num_per_shard (10000/4 = 2500)
     dataset2 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=3000, num_shards=4, shard_id=0)
     assert dataset2.get_dataset_size() == 2500
+    assert len(dataset2) == 2500
 
     count2 = 0
     for _ in dataset2.create_dict_iterator(num_epochs=1):