Add test cases for the len() method and extend the data types supported by the BroadcastTo operator on the GPU

This commit is contained in:
liu-yongqi-63 2023-03-02 15:31:38 +08:00
parent cf083546aa
commit 119100618f
2 changed files with 203 additions and 11 deletions

View File

@ -138,24 +138,62 @@ std::vector<std::pair<KernelAttr, BroadcastToGpuKernelMod::BroadcastToLaunchFunc
&BroadcastToGpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
&BroadcastToGpuKernelMod::LaunchKernel<half>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
&BroadcastToGpuKernelMod::LaunchKernel<int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
&BroadcastToGpuKernelMod::LaunchKernel<int16_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
&BroadcastToGpuKernelMod::LaunchKernel<int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
&BroadcastToGpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
&BroadcastToGpuKernelMod::LaunchKernel<uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
&BroadcastToGpuKernelMod::LaunchKernel<uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
&BroadcastToGpuKernelMod::LaunchKernel<uint32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
&BroadcastToGpuKernelMod::LaunchKernel<uint64_t>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
&BroadcastToGpuKernelMod::LaunchKernel<bool>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex64),
&BroadcastToGpuKernelMod::LaunchKernel<utils::Complex<float>>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeComplex128),
&BroadcastToGpuKernelMod::LaunchKernel<utils::Complex<double>>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
&BroadcastToGpuKernelMod::LaunchKernel<double>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
&BroadcastToGpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
&BroadcastToGpuKernelMod::LaunchKernel<half>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
&BroadcastToGpuKernelMod::LaunchKernel<int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
&BroadcastToGpuKernelMod::LaunchKernel<int16_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
&BroadcastToGpuKernelMod::LaunchKernel<int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
&BroadcastToGpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
&BroadcastToGpuKernelMod::LaunchKernel<uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
&BroadcastToGpuKernelMod::LaunchKernel<uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
&BroadcastToGpuKernelMod::LaunchKernel<uint32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
&BroadcastToGpuKernelMod::LaunchKernel<uint64_t>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
&BroadcastToGpuKernelMod::LaunchKernel<bool>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex64),
&BroadcastToGpuKernelMod::LaunchKernel<utils::Complex<float>>},
// Complex128 data whose second input is Int32. This closes the Int32 group above;
// the Int64 variant of this entry is already registered earlier in the list.
{KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeComplex128),
&BroadcastToGpuKernelMod::LaunchKernel<utils::Complex<double>>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
@ -171,6 +209,11 @@ std::vector<std::pair<KernelAttr, BroadcastToGpuKernelMod::BroadcastToLaunchFunc
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat16),
&BroadcastToGpuKernelMod::LaunchKernel<half>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt8),
&BroadcastToGpuKernelMod::LaunchKernel<int8_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
@ -186,6 +229,41 @@ std::vector<std::pair<KernelAttr, BroadcastToGpuKernelMod::BroadcastToLaunchFunc
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt64),
&BroadcastToGpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt8),
&BroadcastToGpuKernelMod::LaunchKernel<uint8_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt16),
&BroadcastToGpuKernelMod::LaunchKernel<uint16_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt32),
&BroadcastToGpuKernelMod::LaunchKernel<uint32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt64),
&BroadcastToGpuKernelMod::LaunchKernel<uint64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
.AddOutputAttr(kNumberTypeBool),
&BroadcastToGpuKernelMod::LaunchKernel<bool>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
.AddOutputAttr(kNumberTypeComplex64),
&BroadcastToGpuKernelMod::LaunchKernel<utils::Complex<float>>},
// Complex128 data whose second input is a tuple of Int32. This closes the
// Tuple[Int32] group; the Tuple[Int64] variant is registered later in the list.
{KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
.AddOutputAttr(kNumberTypeComplex128),
&BroadcastToGpuKernelMod::LaunchKernel<utils::Complex<double>>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
@ -201,6 +279,11 @@ std::vector<std::pair<KernelAttr, BroadcastToGpuKernelMod::BroadcastToLaunchFunc
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
&BroadcastToGpuKernelMod::LaunchKernel<half>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt8),
&BroadcastToGpuKernelMod::LaunchKernel<int8_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
@ -216,6 +299,41 @@ std::vector<std::pair<KernelAttr, BroadcastToGpuKernelMod::BroadcastToLaunchFunc
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
&BroadcastToGpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
.AddOutputAttr(kNumberTypeUInt8),
&BroadcastToGpuKernelMod::LaunchKernel<uint8_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
.AddOutputAttr(kNumberTypeUInt16),
&BroadcastToGpuKernelMod::LaunchKernel<uint16_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
.AddOutputAttr(kNumberTypeUInt32),
&BroadcastToGpuKernelMod::LaunchKernel<uint32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
.AddOutputAttr(kNumberTypeUInt64),
&BroadcastToGpuKernelMod::LaunchKernel<uint64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
.AddOutputAttr(kNumberTypeBool),
&BroadcastToGpuKernelMod::LaunchKernel<bool>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
.AddOutputAttr(kNumberTypeComplex64),
&BroadcastToGpuKernelMod::LaunchKernel<utils::Complex<float>>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
.AddOutputAttr(kNumberTypeComplex128),
&BroadcastToGpuKernelMod::LaunchKernel<utils::Complex<double>>},
};
std::vector<KernelAttr> BroadcastToGpuKernelMod::GetOpSupport() {

View File

@ -44,15 +44,19 @@ def test_imagenet_rawdata_dataset_size():
"""
ds_total = ds.ImageFolderDataset(IMAGENET_RAWDATA_DIR)
assert ds_total.get_dataset_size() == 6
assert len(ds_total) == 6
ds_shard_1_0 = ds.ImageFolderDataset(IMAGENET_RAWDATA_DIR, num_shards=1, shard_id=0)
assert ds_shard_1_0.get_dataset_size() == 6
assert len(ds_shard_1_0) == 6
ds_shard_2_0 = ds.ImageFolderDataset(IMAGENET_RAWDATA_DIR, num_shards=2, shard_id=0)
assert ds_shard_2_0.get_dataset_size() == 3
assert len(ds_shard_2_0) == 3
ds_shard_3_0 = ds.ImageFolderDataset(IMAGENET_RAWDATA_DIR, num_shards=3, shard_id=0)
assert ds_shard_3_0.get_dataset_size() == 2
assert len(ds_shard_3_0) == 2
def test_imagenet_tf_file_dataset_size():
@ -63,20 +67,25 @@ def test_imagenet_tf_file_dataset_size():
"""
ds_total = ds.TFRecordDataset(IMAGENET_TFFILE_DIR)
assert ds_total.get_dataset_size() == 12
assert len(ds_total) == 12
ds_shard_1_0 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, num_shards=1, shard_id=0, shard_equal_rows=True)
assert ds_shard_1_0.get_dataset_size() == 12
assert len(ds_shard_1_0) == 12
ds_shard_2_0 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, num_shards=2, shard_id=0, shard_equal_rows=True)
assert ds_shard_2_0.get_dataset_size() == 6
assert len(ds_shard_2_0) == 6
ds_shard_3_0 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, num_shards=3, shard_id=0, shard_equal_rows=True)
assert ds_shard_3_0.get_dataset_size() == 4
assert len(ds_shard_3_0) == 4
count = 0
for _ in ds_shard_3_0.create_dict_iterator(num_epochs=1):
count += 1
assert ds_shard_3_0.get_dataset_size() == count
assert len(ds_shard_3_0) == count
# shard_equal_rows is set to False; therefore, get_dataset_size must return count
ds_shard_4_0 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, num_shards=4, shard_id=0)
@ -84,6 +93,7 @@ def test_imagenet_tf_file_dataset_size():
for _ in ds_shard_4_0.create_dict_iterator(num_epochs=1):
count += 1
assert ds_shard_4_0.get_dataset_size() == count
assert len(ds_shard_4_0) == count
def test_mnist_dataset_size():
@ -94,23 +104,36 @@ def test_mnist_dataset_size():
"""
ds_total = ds.MnistDataset(MNIST_DATA_DIR)
assert ds_total.get_dataset_size() == 10000
assert len(ds_total) == 10000
# test get dataset_size with the usage arg
test_size = ds.MnistDataset(MNIST_DATA_DIR, usage="test").get_dataset_size()
test_dataset = ds.MnistDataset(MNIST_DATA_DIR, usage="test")
train_dataset = ds.MnistDataset(MNIST_DATA_DIR, usage="train")
all_dataset = ds.MnistDataset(MNIST_DATA_DIR, usage="all")
test_size = test_dataset.get_dataset_size()
assert test_size == 10000
train_size = ds.MnistDataset(MNIST_DATA_DIR, usage="train").get_dataset_size()
assert len(test_dataset) == 10000
train_size = train_dataset.get_dataset_size()
assert train_size == 0
all_size = ds.MnistDataset(MNIST_DATA_DIR, usage="all").get_dataset_size()
assert len(train_dataset) == train_size
all_size = all_dataset.get_dataset_size()
assert all_size == 10000
assert len(all_dataset) == 10000
ds_shard_1_0 = ds.MnistDataset(MNIST_DATA_DIR, num_shards=1, shard_id=0)
assert ds_shard_1_0.get_dataset_size() == 10000
assert len(ds_shard_1_0) == 10000
ds_shard_2_0 = ds.MnistDataset(MNIST_DATA_DIR, num_shards=2, shard_id=0)
assert ds_shard_2_0.get_dataset_size() == 5000
assert len(ds_shard_2_0) == 5000
ds_shard_3_0 = ds.MnistDataset(MNIST_DATA_DIR, num_shards=3, shard_id=0)
assert ds_shard_3_0.get_dataset_size() == 3334
assert len(ds_shard_3_0) == 3334
def test_mind_dataset_size():
@ -121,9 +144,11 @@ def test_mind_dataset_size():
"""
dataset = ds.MindDataset(MIND_CV_FILE_NAME + "0")
assert dataset.get_dataset_size() == 20
assert len(dataset) == 20
dataset_shard_2_0 = ds.MindDataset(MIND_CV_FILE_NAME + "0", num_shards=2, shard_id=0)
assert dataset_shard_2_0.get_dataset_size() == 10
assert len(dataset_shard_2_0) == 10
def test_manifest_dataset_size():
@ -134,15 +159,19 @@ def test_manifest_dataset_size():
"""
ds_total = ds.ManifestDataset(MANIFEST_DATA_FILE)
assert ds_total.get_dataset_size() == 4
assert len(ds_total) == 4
ds_shard_1_0 = ds.ManifestDataset(MANIFEST_DATA_FILE, num_shards=1, shard_id=0)
assert ds_shard_1_0.get_dataset_size() == 4
assert len(ds_shard_1_0) == 4
ds_shard_2_0 = ds.ManifestDataset(MANIFEST_DATA_FILE, num_shards=2, shard_id=0)
assert ds_shard_2_0.get_dataset_size() == 2
assert len(ds_shard_2_0) == 2
ds_shard_3_0 = ds.ManifestDataset(MANIFEST_DATA_FILE, num_shards=3, shard_id=0)
assert ds_shard_3_0.get_dataset_size() == 2
assert len(ds_shard_3_0) == 2
def test_cifar10_dataset_size():
@ -153,27 +182,40 @@ def test_cifar10_dataset_size():
"""
ds_total = ds.Cifar10Dataset(CIFAR10_DATA_DIR)
assert ds_total.get_dataset_size() == 10000
assert len(ds_total) == 10000
# test get_dataset_size with usage flag
train_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="train").get_dataset_size()
assert train_size == 0
train_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="train").get_dataset_size()
assert train_size == 10000
train_cifar10dataset = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="train")
train_cifar100dataset = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="train")
all_cifar10dataset = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="all")
all_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="all").get_dataset_size()
train_size = train_cifar100dataset.get_dataset_size()
assert train_size == 0
assert len(train_cifar100dataset) == train_size
train_size = train_cifar10dataset.get_dataset_size()
assert train_size == 10000
assert len(train_cifar10dataset) == 10000
all_size = all_cifar10dataset.get_dataset_size()
assert all_size == 10000
assert len(all_cifar10dataset) == 10000
ds_shard_1_0 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_shards=1, shard_id=0)
assert ds_shard_1_0.get_dataset_size() == 10000
assert len(ds_shard_1_0) == 10000
ds_shard_2_0 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_shards=2, shard_id=0)
assert ds_shard_2_0.get_dataset_size() == 5000
assert len(ds_shard_2_0) == 5000
ds_shard_3_0 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_shards=3, shard_id=0)
assert ds_shard_3_0.get_dataset_size() == 3334
assert len(ds_shard_3_0) == 3334
ds_shard_7_0 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_shards=7, shard_id=0)
assert ds_shard_7_0.get_dataset_size() == 1429
assert len(ds_shard_7_0) == 1429
def test_cifar100_dataset_size():
@ -184,21 +226,31 @@ def test_cifar100_dataset_size():
"""
ds_total = ds.Cifar100Dataset(CIFAR100_DATA_DIR)
assert ds_total.get_dataset_size() == 10000
assert len(ds_total) == 10000
# test get_dataset_size with usage flag
test_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="test").get_dataset_size()
test_cifar100dataset = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="test")
all_cifar100dataset = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="all")
test_size = test_cifar100dataset.get_dataset_size()
assert test_size == 10000
all_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="all").get_dataset_size()
assert len(test_cifar100dataset) == 10000
all_size = all_cifar100dataset.get_dataset_size()
assert all_size == 10000
assert len(all_cifar100dataset) == 10000
ds_shard_1_0 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_shards=1, shard_id=0)
assert ds_shard_1_0.get_dataset_size() == 10000
assert len(ds_shard_1_0) == 10000
ds_shard_2_0 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_shards=2, shard_id=0)
assert ds_shard_2_0.get_dataset_size() == 5000
assert len(ds_shard_2_0) == 5000
ds_shard_3_0 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_shards=3, shard_id=0)
assert ds_shard_3_0.get_dataset_size() == 3334
assert len(ds_shard_3_0) == 3334
def test_voc_dataset_size():
@ -209,10 +261,12 @@ def test_voc_dataset_size():
"""
dataset = ds.VOCDataset(VOC_DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True)
assert dataset.get_dataset_size() == 10
assert len(dataset) == 10
dataset_shard_2_0 = ds.VOCDataset(VOC_DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True,
num_shards=2, shard_id=0)
assert dataset_shard_2_0.get_dataset_size() == 5
assert len(dataset_shard_2_0) == 5
def test_coco_dataset_size():
@ -224,10 +278,12 @@ def test_coco_dataset_size():
dataset = ds.CocoDataset(COCO_DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection",
decode=True, shuffle=False)
assert dataset.get_dataset_size() == 6
assert len(dataset) == 6
dataset_shard_2_0 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection", decode=True,
shuffle=False, num_shards=2, shard_id=0)
assert dataset_shard_2_0.get_dataset_size() == 3
assert len(dataset_shard_2_0) == 3
def test_celeba_dataset_size():
@ -238,9 +294,11 @@ def test_celeba_dataset_size():
"""
dataset = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True)
assert dataset.get_dataset_size() == 4
assert len(dataset) == 4
dataset_shard_2_0 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True, num_shards=2, shard_id=0)
assert dataset_shard_2_0.get_dataset_size() == 2
assert len(dataset_shard_2_0) == 2
def test_clue_dataset_size():
@ -251,9 +309,11 @@ def test_clue_dataset_size():
"""
dataset = ds.CLUEDataset(CLUE_FILE, task='AFQMC', usage='train', shuffle=False)
assert dataset.get_dataset_size() == 3
assert len(dataset) == 3
dataset_shard_2_0 = ds.CLUEDataset(CLUE_FILE, task='AFQMC', usage='train', shuffle=False, num_shards=2, shard_id=0)
assert dataset_shard_2_0.get_dataset_size() == 2
assert len(dataset_shard_2_0) == 2
def test_csv_dataset_size():
@ -265,10 +325,12 @@ def test_csv_dataset_size():
dataset = ds.CSVDataset(CSV_FILE, column_defaults=["0", 0, 0.0, "0"], column_names=['1', '2', '3', '4'],
shuffle=False)
assert dataset.get_dataset_size() == 3
assert len(dataset) == 3
dataset_shard_2_0 = ds.CSVDataset(CSV_FILE, column_defaults=["0", 0, 0.0, "0"], column_names=['1', '2', '3', '4'],
shuffle=False, num_shards=2, shard_id=0)
assert dataset_shard_2_0.get_dataset_size() == 2
assert len(dataset_shard_2_0) == 2
def test_text_file_dataset_size():
@ -279,9 +341,11 @@ def test_text_file_dataset_size():
"""
dataset = ds.TextFileDataset(TEXT_DATA_FILE)
assert dataset.get_dataset_size() == 3
assert len(dataset) == 3
dataset_shard_2_0 = ds.TextFileDataset(TEXT_DATA_FILE, num_shards=2, shard_id=0)
assert dataset_shard_2_0.get_dataset_size() == 2
assert len(dataset_shard_2_0) == 2
def test_padded_dataset_size():
@ -292,6 +356,7 @@ def test_padded_dataset_size():
"""
dataset = ds.PaddedDataset([{"data": [1, 2, 3]}, {"data": [1, 0, 1]}])
assert dataset.get_dataset_size() == 2
assert len(dataset) == 2
def test_pipeline_get_dataset_size():
@ -302,25 +367,32 @@ def test_pipeline_get_dataset_size():
"""
dataset = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, SCHEMA_FILE, columns_list=["image"], shuffle=False)
assert dataset.get_dataset_size() == 12
assert len(dataset) == 12
dataset = dataset.shuffle(buffer_size=3)
assert dataset.get_dataset_size() == 12
assert len(dataset) == 12
decode_op = vision.Decode()
resize_op = vision.RandomResize(10)
dataset = dataset.map([decode_op, resize_op], input_columns=["image"])
assert dataset.get_dataset_size() == 12
assert len(dataset) == 12
dataset = dataset.batch(batch_size=3)
assert dataset.get_dataset_size() == 4
assert len(dataset) == 4
dataset = dataset.repeat(count=2)
assert dataset.get_dataset_size() == 8
assert len(dataset) == 8
tf1 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, shuffle=True)
tf2 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, shuffle=True)
assert tf2.concat(tf1).get_dataset_size() == 24
tf3 = tf2.concat(tf1)
assert tf3.get_dataset_size() == 24
assert len(tf3) == 24
def test_distributed_get_dataset_size():
@ -332,6 +404,7 @@ def test_distributed_get_dataset_size():
# Test get dataset size when num_samples is less than num_per_shard (10000/4 = 2500)
dataset1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=2000, num_shards=4, shard_id=0)
assert dataset1.get_dataset_size() == 2000
assert len(dataset1) == 2000
count1 = 0
for _ in dataset1.create_dict_iterator(num_epochs=1):
@ -341,6 +414,7 @@ def test_distributed_get_dataset_size():
# Test get dataset size when num_samples is more than num_per_shard (10000/4 = 2500)
dataset2 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=3000, num_shards=4, shard_id=0)
assert dataset2.get_dataset_size() == 2500
assert len(dataset2) == 2500
count2 = 0
for _ in dataset2.create_dict_iterator(num_epochs=1):