forked from mindspore-Ecosystem/mindspore
!49655 Add test cases for the len() method and the data types supported by the BroadcastTo operator on the GPU
Merge pull request !49655 from 刘勇琪/master-decrypt-column
This commit is contained in:
commit 0db81d3bb6
@@ -138,24 +138,62 @@ std::vector<std::pair<KernelAttr, BroadcastToGpuKernelMod::BroadcastToLaunchFunc
   &BroadcastToGpuKernelMod::LaunchKernel<float>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
   &BroadcastToGpuKernelMod::LaunchKernel<half>},
  {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
   &BroadcastToGpuKernelMod::LaunchKernel<int8_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
   &BroadcastToGpuKernelMod::LaunchKernel<int16_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
   &BroadcastToGpuKernelMod::LaunchKernel<int32_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
   &BroadcastToGpuKernelMod::LaunchKernel<int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
   &BroadcastToGpuKernelMod::LaunchKernel<uint8_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
   &BroadcastToGpuKernelMod::LaunchKernel<uint16_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
   &BroadcastToGpuKernelMod::LaunchKernel<uint32_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
   &BroadcastToGpuKernelMod::LaunchKernel<uint64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
   &BroadcastToGpuKernelMod::LaunchKernel<bool>},
  {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex64),
   &BroadcastToGpuKernelMod::LaunchKernel<utils::Complex<float>>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeComplex128)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeComplex128),
   &BroadcastToGpuKernelMod::LaunchKernel<utils::Complex<double>>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
   &BroadcastToGpuKernelMod::LaunchKernel<double>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
   &BroadcastToGpuKernelMod::LaunchKernel<float>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
   &BroadcastToGpuKernelMod::LaunchKernel<half>},
  {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
   &BroadcastToGpuKernelMod::LaunchKernel<int8_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
   &BroadcastToGpuKernelMod::LaunchKernel<int16_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
   &BroadcastToGpuKernelMod::LaunchKernel<int32_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
   &BroadcastToGpuKernelMod::LaunchKernel<int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
   &BroadcastToGpuKernelMod::LaunchKernel<uint8_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
   &BroadcastToGpuKernelMod::LaunchKernel<uint16_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
   &BroadcastToGpuKernelMod::LaunchKernel<uint32_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
   &BroadcastToGpuKernelMod::LaunchKernel<uint64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
   &BroadcastToGpuKernelMod::LaunchKernel<bool>},
  {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex64),
   &BroadcastToGpuKernelMod::LaunchKernel<utils::Complex<float>>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeComplex128)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeComplex128),
   &BroadcastToGpuKernelMod::LaunchKernel<utils::Complex<double>>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
@@ -171,6 +209,11 @@ std::vector<std::pair<KernelAttr, BroadcastToGpuKernelMod::BroadcastToLaunchFunc
     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeFloat16),
   &BroadcastToGpuKernelMod::LaunchKernel<half>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt8)
     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeInt8),
   &BroadcastToGpuKernelMod::LaunchKernel<int8_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt16)
     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
@@ -186,6 +229,41 @@ std::vector<std::pair<KernelAttr, BroadcastToGpuKernelMod::BroadcastToLaunchFunc
     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeInt64),
   &BroadcastToGpuKernelMod::LaunchKernel<int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt8)
     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeUInt8),
   &BroadcastToGpuKernelMod::LaunchKernel<uint8_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt16)
     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeUInt16),
   &BroadcastToGpuKernelMod::LaunchKernel<uint16_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt32)
     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeUInt32),
   &BroadcastToGpuKernelMod::LaunchKernel<uint32_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt64)
     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeUInt64),
   &BroadcastToGpuKernelMod::LaunchKernel<uint64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeBool)
     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeBool),
   &BroadcastToGpuKernelMod::LaunchKernel<bool>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeComplex64)
     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeComplex64),
   &BroadcastToGpuKernelMod::LaunchKernel<utils::Complex<float>>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeComplex128)
     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeComplex128),
   &BroadcastToGpuKernelMod::LaunchKernel<utils::Complex<double>>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
@@ -201,6 +279,11 @@ std::vector<std::pair<KernelAttr, BroadcastToGpuKernelMod::BroadcastToLaunchFunc
     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeFloat16),
   &BroadcastToGpuKernelMod::LaunchKernel<half>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt8)
     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeInt8),
   &BroadcastToGpuKernelMod::LaunchKernel<int8_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt16)
     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
@@ -216,6 +299,41 @@ std::vector<std::pair<KernelAttr, BroadcastToGpuKernelMod::BroadcastToLaunchFunc
     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeInt64),
   &BroadcastToGpuKernelMod::LaunchKernel<int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt8)
     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeUInt8),
   &BroadcastToGpuKernelMod::LaunchKernel<uint8_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt16)
     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeUInt16),
   &BroadcastToGpuKernelMod::LaunchKernel<uint16_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt32)
     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeUInt32),
   &BroadcastToGpuKernelMod::LaunchKernel<uint32_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt64)
     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeUInt64),
   &BroadcastToGpuKernelMod::LaunchKernel<uint64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeBool)
     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeBool),
   &BroadcastToGpuKernelMod::LaunchKernel<bool>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeComplex64)
     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeComplex64),
   &BroadcastToGpuKernelMod::LaunchKernel<utils::Complex<float>>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeComplex128)
     .AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeComplex128),
   &BroadcastToGpuKernelMod::LaunchKernel<utils::Complex<double>>},
};

std::vector<KernelAttr> BroadcastToGpuKernelMod::GetOpSupport() {
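In each entry above, the first AddInputAttr is the data dtype and the second describes the target-shape input, accepted as an int32/int64 value or tuple; the table covers floats, the full signed/unsigned integer set, bool, and both complex types. A hedged usage sketch from the Python side (assuming the standard mindspore.ops.broadcast_to API and a GPU device target):

    import numpy as np
    import mindspore as ms
    from mindspore import ops

    # complex64 is among the dtypes registered above; the target shape is a tuple of ints.
    x = ms.Tensor(np.ones((1, 3), dtype=np.complex64))
    y = ops.broadcast_to(x, (2, 3))
    assert y.shape == (2, 3)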
@@ -44,15 +44,19 @@ def test_imagenet_rawdata_dataset_size():
    """
    ds_total = ds.ImageFolderDataset(IMAGENET_RAWDATA_DIR)
    assert ds_total.get_dataset_size() == 6
    assert len(ds_total) == 6

    ds_shard_1_0 = ds.ImageFolderDataset(IMAGENET_RAWDATA_DIR, num_shards=1, shard_id=0)
    assert ds_shard_1_0.get_dataset_size() == 6
    assert len(ds_shard_1_0) == 6

    ds_shard_2_0 = ds.ImageFolderDataset(IMAGENET_RAWDATA_DIR, num_shards=2, shard_id=0)
    assert ds_shard_2_0.get_dataset_size() == 3
    assert len(ds_shard_2_0) == 3

    ds_shard_3_0 = ds.ImageFolderDataset(IMAGENET_RAWDATA_DIR, num_shards=3, shard_id=0)
    assert ds_shard_3_0.get_dataset_size() == 2
    assert len(ds_shard_3_0) == 2


def test_imagenet_tf_file_dataset_size():
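These paired assertions hold as long as len(dataset) returns the same value as get_dataset_size(). A minimal sketch of that assumed delegation (illustrative only, not the MindSpore implementation):

    # Hypothetical stand-in for the dataset classes under test: __len__ is
    # assumed to forward to get_dataset_size().
    class SizedDataset:
        def __init__(self, num_rows):
            self.num_rows = num_rows

        def get_dataset_size(self):
            # Stub; the real method walks the dataset pipeline.
            return self.num_rows

        def __len__(self):
            return self.get_dataset_size()

    ds_stub = SizedDataset(6)
    assert len(ds_stub) == ds_stub.get_dataset_size() == 6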
@@ -63,20 +67,25 @@ def test_imagenet_tf_file_dataset_size():
    """
    ds_total = ds.TFRecordDataset(IMAGENET_TFFILE_DIR)
    assert ds_total.get_dataset_size() == 12
    assert len(ds_total) == 12

    ds_shard_1_0 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, num_shards=1, shard_id=0, shard_equal_rows=True)
    assert ds_shard_1_0.get_dataset_size() == 12
    assert len(ds_shard_1_0) == 12

    ds_shard_2_0 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, num_shards=2, shard_id=0, shard_equal_rows=True)
    assert ds_shard_2_0.get_dataset_size() == 6
    assert len(ds_shard_2_0) == 6

    ds_shard_3_0 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, num_shards=3, shard_id=0, shard_equal_rows=True)
    assert ds_shard_3_0.get_dataset_size() == 4
    assert len(ds_shard_3_0) == 4

    count = 0
    for _ in ds_shard_3_0.create_dict_iterator(num_epochs=1):
        count += 1
    assert ds_shard_3_0.get_dataset_size() == count
    assert len(ds_shard_3_0) == count

    # shard_equal_rows is set to False; therefore, get_dataset_size must return count
    ds_shard_4_0 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, num_shards=4, shard_id=0)
@@ -84,6 +93,7 @@ def test_imagenet_tf_file_dataset_size():
    for _ in ds_shard_4_0.create_dict_iterator(num_epochs=1):
        count += 1
    assert ds_shard_4_0.get_dataset_size() == count
    assert len(ds_shard_4_0) == count


def test_mnist_dataset_size():
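The counting loop pins down the two sizing modes: with shard_equal_rows=True the shard size is predictable up front, while with the default False the reported size must simply match whatever the shard actually yields, so the test iterates and compares. A worked check of the assumed equal-rows arithmetic:

    # Assumed rule for shard_equal_rows=True: per-shard size = ceil(total / num_shards).
    import math

    total = 12
    for num_shards, expected in ((1, 12), (2, 6), (3, 4)):
        assert math.ceil(total / num_shards) == expected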
@@ -94,23 +104,36 @@ def test_mnist_dataset_size():
    """
    ds_total = ds.MnistDataset(MNIST_DATA_DIR)
    assert ds_total.get_dataset_size() == 10000
    assert len(ds_total) == 10000

    # test get_dataset_size with the usage arg
    test_size = ds.MnistDataset(MNIST_DATA_DIR, usage="test").get_dataset_size()
    test_dataset = ds.MnistDataset(MNIST_DATA_DIR, usage="test")
    train_dataset = ds.MnistDataset(MNIST_DATA_DIR, usage="train")
    all_dataset = ds.MnistDataset(MNIST_DATA_DIR, usage="all")

    test_size = test_dataset.get_dataset_size()
    assert test_size == 10000
    train_size = ds.MnistDataset(MNIST_DATA_DIR, usage="train").get_dataset_size()
    assert len(test_dataset) == 10000

    train_size = train_dataset.get_dataset_size()
    assert train_size == 0
    all_size = ds.MnistDataset(MNIST_DATA_DIR, usage="all").get_dataset_size()
    assert len(train_dataset) == train_size

    all_size = all_dataset.get_dataset_size()
    assert all_size == 10000
    assert len(all_dataset) == 10000

    ds_shard_1_0 = ds.MnistDataset(MNIST_DATA_DIR, num_shards=1, shard_id=0)
    assert ds_shard_1_0.get_dataset_size() == 10000
    assert len(ds_shard_1_0) == 10000

    ds_shard_2_0 = ds.MnistDataset(MNIST_DATA_DIR, num_shards=2, shard_id=0)
    assert ds_shard_2_0.get_dataset_size() == 5000
    assert len(ds_shard_2_0) == 5000

    ds_shard_3_0 = ds.MnistDataset(MNIST_DATA_DIR, num_shards=3, shard_id=0)
    assert ds_shard_3_0.get_dataset_size() == 3334
    assert len(ds_shard_3_0) == 3334


def test_mind_dataset_size():
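The MNIST shard sizes asserted above follow the usual ceiling division, per-shard size = ceil(total / num_shards), which is why three shards of 10000 rows report 3334 rather than 3333. A quick check of the asserted values:

    import math

    total = 10000
    # num_shards -> per-shard size asserted in the test above
    for num_shards, expected in ((1, 10000), (2, 5000), (3, 3334)):
        assert math.ceil(total / num_shards) == expected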
@@ -121,9 +144,11 @@ def test_mind_dataset_size():
    """
    dataset = ds.MindDataset(MIND_CV_FILE_NAME + "0")
    assert dataset.get_dataset_size() == 20
    assert len(dataset) == 20

    dataset_shard_2_0 = ds.MindDataset(MIND_CV_FILE_NAME + "0", num_shards=2, shard_id=0)
    assert dataset_shard_2_0.get_dataset_size() == 10
    assert len(dataset_shard_2_0) == 10


def test_manifest_dataset_size():
@@ -134,15 +159,19 @@ def test_manifest_dataset_size():
    """
    ds_total = ds.ManifestDataset(MANIFEST_DATA_FILE)
    assert ds_total.get_dataset_size() == 4
    assert len(ds_total) == 4

    ds_shard_1_0 = ds.ManifestDataset(MANIFEST_DATA_FILE, num_shards=1, shard_id=0)
    assert ds_shard_1_0.get_dataset_size() == 4
    assert len(ds_shard_1_0) == 4

    ds_shard_2_0 = ds.ManifestDataset(MANIFEST_DATA_FILE, num_shards=2, shard_id=0)
    assert ds_shard_2_0.get_dataset_size() == 2
    assert len(ds_shard_2_0) == 2

    ds_shard_3_0 = ds.ManifestDataset(MANIFEST_DATA_FILE, num_shards=3, shard_id=0)
    assert ds_shard_3_0.get_dataset_size() == 2
    assert len(ds_shard_3_0) == 2


def test_cifar10_dataset_size():
@@ -153,27 +182,40 @@ def test_cifar10_dataset_size():
    """
    ds_total = ds.Cifar10Dataset(CIFAR10_DATA_DIR)
    assert ds_total.get_dataset_size() == 10000
    assert len(ds_total) == 10000

    # test get_dataset_size with usage flag
    train_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="train").get_dataset_size()
    assert train_size == 0
    train_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="train").get_dataset_size()
    assert train_size == 10000
    train_cifar10dataset = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="train")
    train_cifar100dataset = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="train")
    all_cifar10dataset = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="all")

    all_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="all").get_dataset_size()
    train_size = train_cifar100dataset.get_dataset_size()
    assert train_size == 0
    assert len(train_cifar100dataset) == train_size

    train_size = train_cifar10dataset.get_dataset_size()
    assert train_size == 10000
    assert len(train_cifar10dataset) == 10000

    all_size = all_cifar10dataset.get_dataset_size()
    assert all_size == 10000
    assert len(all_cifar10dataset) == 10000

    ds_shard_1_0 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_shards=1, shard_id=0)
    assert ds_shard_1_0.get_dataset_size() == 10000
    assert len(ds_shard_1_0) == 10000

    ds_shard_2_0 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_shards=2, shard_id=0)
    assert ds_shard_2_0.get_dataset_size() == 5000
    assert len(ds_shard_2_0) == 5000

    ds_shard_3_0 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_shards=3, shard_id=0)
    assert ds_shard_3_0.get_dataset_size() == 3334
    assert len(ds_shard_3_0) == 3334

    ds_shard_7_0 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_shards=7, shard_id=0)
    assert ds_shard_7_0.get_dataset_size() == 1429
    assert len(ds_shard_7_0) == 1429


def test_cifar100_dataset_size():
@@ -184,21 +226,31 @@ def test_cifar100_dataset_size():
    """
    ds_total = ds.Cifar100Dataset(CIFAR100_DATA_DIR)
    assert ds_total.get_dataset_size() == 10000
    assert len(ds_total) == 10000

    # test get_dataset_size with usage flag
    test_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="test").get_dataset_size()
    test_cifar100dataset = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="test")
    all_cifar100dataset = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="all")

    test_size = test_cifar100dataset.get_dataset_size()
    assert test_size == 10000
    all_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="all").get_dataset_size()
    assert len(test_cifar100dataset) == 10000

    all_size = all_cifar100dataset.get_dataset_size()
    assert all_size == 10000
    assert len(all_cifar100dataset) == 10000

    ds_shard_1_0 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_shards=1, shard_id=0)
    assert ds_shard_1_0.get_dataset_size() == 10000
    assert len(ds_shard_1_0) == 10000

    ds_shard_2_0 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_shards=2, shard_id=0)
    assert ds_shard_2_0.get_dataset_size() == 5000
    assert len(ds_shard_2_0) == 5000

    ds_shard_3_0 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_shards=3, shard_id=0)
    assert ds_shard_3_0.get_dataset_size() == 3334
    assert len(ds_shard_3_0) == 3334


def test_voc_dataset_size():
@@ -209,10 +261,12 @@ def test_voc_dataset_size():
    """
    dataset = ds.VOCDataset(VOC_DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True)
    assert dataset.get_dataset_size() == 10
    assert len(dataset) == 10

    dataset_shard_2_0 = ds.VOCDataset(VOC_DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True,
                                      num_shards=2, shard_id=0)
    assert dataset_shard_2_0.get_dataset_size() == 5
    assert len(dataset_shard_2_0) == 5


def test_coco_dataset_size():
@@ -224,10 +278,12 @@ def test_coco_dataset_size():
    dataset = ds.CocoDataset(COCO_DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection",
                             decode=True, shuffle=False)
    assert dataset.get_dataset_size() == 6
    assert len(dataset) == 6

    dataset_shard_2_0 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection", decode=True,
                                       shuffle=False, num_shards=2, shard_id=0)
    assert dataset_shard_2_0.get_dataset_size() == 3
    assert len(dataset_shard_2_0) == 3


def test_celeba_dataset_size():
@@ -238,9 +294,11 @@ def test_celeba_dataset_size():
    """
    dataset = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True)
    assert dataset.get_dataset_size() == 4
    assert len(dataset) == 4

    dataset_shard_2_0 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True, num_shards=2, shard_id=0)
    assert dataset_shard_2_0.get_dataset_size() == 2
    assert len(dataset_shard_2_0) == 2


def test_clue_dataset_size():
@@ -251,9 +309,11 @@ def test_clue_dataset_size():
    """
    dataset = ds.CLUEDataset(CLUE_FILE, task='AFQMC', usage='train', shuffle=False)
    assert dataset.get_dataset_size() == 3
    assert len(dataset) == 3

    dataset_shard_2_0 = ds.CLUEDataset(CLUE_FILE, task='AFQMC', usage='train', shuffle=False, num_shards=2, shard_id=0)
    assert dataset_shard_2_0.get_dataset_size() == 2
    assert len(dataset_shard_2_0) == 2


def test_csv_dataset_size():
@@ -265,10 +325,12 @@ def test_csv_dataset_size():
    dataset = ds.CSVDataset(CSV_FILE, column_defaults=["0", 0, 0.0, "0"], column_names=['1', '2', '3', '4'],
                            shuffle=False)
    assert dataset.get_dataset_size() == 3
    assert len(dataset) == 3

    dataset_shard_2_0 = ds.CSVDataset(CSV_FILE, column_defaults=["0", 0, 0.0, "0"], column_names=['1', '2', '3', '4'],
                                      shuffle=False, num_shards=2, shard_id=0)
    assert dataset_shard_2_0.get_dataset_size() == 2
    assert len(dataset_shard_2_0) == 2


def test_text_file_dataset_size():
@@ -279,9 +341,11 @@ def test_text_file_dataset_size():
    """
    dataset = ds.TextFileDataset(TEXT_DATA_FILE)
    assert dataset.get_dataset_size() == 3
    assert len(dataset) == 3

    dataset_shard_2_0 = ds.TextFileDataset(TEXT_DATA_FILE, num_shards=2, shard_id=0)
    assert dataset_shard_2_0.get_dataset_size() == 2
    assert len(dataset_shard_2_0) == 2


def test_padded_dataset_size():
@@ -292,6 +356,7 @@ def test_padded_dataset_size():
    """
    dataset = ds.PaddedDataset([{"data": [1, 2, 3]}, {"data": [1, 0, 1]}])
    assert dataset.get_dataset_size() == 2
    assert len(dataset) == 2


def test_pipeline_get_dataset_size():
@@ -302,25 +367,32 @@ def test_pipeline_get_dataset_size():
    """
    dataset = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, SCHEMA_FILE, columns_list=["image"], shuffle=False)
    assert dataset.get_dataset_size() == 12
    assert len(dataset) == 12

    dataset = dataset.shuffle(buffer_size=3)
    assert dataset.get_dataset_size() == 12
    assert len(dataset) == 12

    decode_op = vision.Decode()
    resize_op = vision.RandomResize(10)

    dataset = dataset.map([decode_op, resize_op], input_columns=["image"])
    assert dataset.get_dataset_size() == 12
    assert len(dataset) == 12

    dataset = dataset.batch(batch_size=3)
    assert dataset.get_dataset_size() == 4
    assert len(dataset) == 4

    dataset = dataset.repeat(count=2)
    assert dataset.get_dataset_size() == 8
    assert len(dataset) == 8

    tf1 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, shuffle=True)
    tf2 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, shuffle=True)
    assert tf2.concat(tf1).get_dataset_size() == 24
    tf3 = tf2.concat(tf1)
    assert tf3.get_dataset_size() == 24
    assert len(tf3) == 24


def test_distributed_get_dataset_size():
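The pipeline assertions above are consistent with the usual size rules: shuffle and map preserve the row count, batch divides it (rounding up), repeat multiplies it, and concat adds the sizes of its inputs. A worked trace under those assumed rules:

    import math

    size = 12                   # TFRecordDataset source
    size = size                 # shuffle(buffer_size=3): unchanged
    size = size                 # map(...): unchanged
    size = math.ceil(size / 3)  # batch(batch_size=3): 12 -> 4
    size *= 2                   # repeat(count=2): 4 -> 8
    assert size == 8
    assert 12 + 12 == 24        # concat of two 12-row datasets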
@@ -332,6 +404,7 @@ def test_distributed_get_dataset_size():
    # Test get_dataset_size when num_samples is less than num_per_shard (10000/4 = 2500)
    dataset1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=2000, num_shards=4, shard_id=0)
    assert dataset1.get_dataset_size() == 2000
    assert len(dataset1) == 2000

    count1 = 0
    for _ in dataset1.create_dict_iterator(num_epochs=1):
@@ -341,6 +414,7 @@ def test_distributed_get_dataset_size():
    # Test get_dataset_size when num_samples is more than num_per_shard (10000/4 = 2500)
    dataset2 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=3000, num_shards=4, shard_id=0)
    assert dataset2.get_dataset_size() == 2500
    assert len(dataset2) == 2500

    count2 = 0
    for _ in dataset2.create_dict_iterator(num_epochs=1):
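Both distributed cases match size = min(num_samples, ceil(total / num_shards)): 2000 stays under the per-shard cap of 2500, while 3000 is clipped to it. A worked check under that assumed rule:

    import math

    total, num_shards = 10000, 4
    num_per_shard = math.ceil(total / num_shards)  # 2500
    assert min(2000, num_per_shard) == 2000        # num_samples below the cap
    assert min(3000, num_per_shard) == 2500        # num_samples clipped to the cap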