forked from mindspore-Ecosystem/mindspore
!18892 [assistant][ops] Add new dataset operator SBD.
Merge pull request !18892 from Rainfor/sbd
This commit is contained in:
commit
ef77050348
|
@ -41,6 +41,8 @@ import weakref
|
|||
import platform
|
||||
import psutil
|
||||
import numpy as np
|
||||
from scipy.io import loadmat
|
||||
from PIL import Image
|
||||
|
||||
import mindspore._c_dataengine as cde
|
||||
from mindspore._c_expression import typing
|
||||
|
@ -61,7 +63,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
|
|||
check_celebadataset, check_minddataset, check_generatordataset, check_sync_wait, check_zip_dataset, \
|
||||
check_add_column, check_textfiledataset, check_concat, check_random_dataset, check_split, \
|
||||
check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset, check_paddeddataset, \
|
||||
check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send
|
||||
check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, check_sb_dataset
|
||||
from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
|
||||
get_prefetch_size
|
||||
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
|
||||
|
@ -5668,3 +5670,190 @@ class PaddedDataset(GeneratorDataset):
|
|||
super().__init__(dataset, column_names=dataset.column_names, num_shards=None, shard_id=None, shuffle=False)
|
||||
self._dataset_size = len(dataset.padded_samples)
|
||||
self.padded_samples = padded_samples
|
||||
|
||||
|
||||
class SBDataset(GeneratorDataset):
|
||||
"""
|
||||
A source dataset for reading and parsing Semantic Boundaries Dataset.
|
||||
|
||||
The generated dataset has two columns: :py:obj:`[image, task]`.
|
||||
The tensor of column :py:obj:`image` is of the uint8 type.
|
||||
The tensor of column :py:obj:`task` contains 20 images of the uint8 type if `task` is `Boundaries` otherwise
|
||||
contains 1 image of the uint8 type.
|
||||
|
||||
Args:
|
||||
dataset_dir (str): Path to the root directory that contains the dataset.
|
||||
task (str, optional): Acceptable tasks include `Boundaries` or `Segmentation` (default=`Boundaries`).
|
||||
usage (str, optional): Acceptable usages include `train`, `val`, `train_noval` and `all` (default=`all`).
|
||||
num_samples (int, optional): The number of images to be included in the dataset.
|
||||
(default=None, all images).
|
||||
num_parallel_workers (int, optional): Number of workers to read the data
|
||||
(default=None, number set in the config).
|
||||
shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected
|
||||
order behavior shown in the table).
|
||||
sampler (Sampler, optional): Object used to choose samples from the
|
||||
dataset (default=None, expected order behavior shown in the table).
|
||||
num_shards (int, optional): Number of shards that the dataset will be divided
|
||||
into (default=None). When this argument is specified, `num_samples` reflects
|
||||
the max sample number of per shard.
|
||||
shard_id (int, optional): The shard ID within num_shards (default=None). This
|
||||
argument can only be specified when num_shards is also specified.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If dataset_dir is not valid or does not contain data files.
|
||||
RuntimeError: If num_parallel_workers exceeds the max thread numbers.
|
||||
RuntimeError: If sampler and shuffle are specified at the same time.
|
||||
RuntimeError: If sampler and sharding are specified at the same time.
|
||||
RuntimeError: If num_shards is specified but shard_id is None.
|
||||
RuntimeError: If shard_id is specified but num_shards is None.
|
||||
ValueError: If dataset_dir is not exist.
|
||||
ValueError: If task is not in [`Boundaries`, `Segmentation`].
|
||||
ValueError: If usage is not in [`train`, `val`, `train_noval`, `all`].
|
||||
ValueError: If shard_id is invalid (< 0 or >= num_shards).
|
||||
|
||||
Note:
|
||||
- This dataset can take in a sampler. `sampler` and `shuffle` are mutually exclusive.
|
||||
The table below shows what input arguments are allowed and their expected behavior.
|
||||
|
||||
.. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
|
||||
:widths: 25 25 50
|
||||
:header-rows: 1
|
||||
|
||||
* - Parameter `sampler`
|
||||
- Parameter `shuffle`
|
||||
- Expected Order Behavior
|
||||
* - None
|
||||
- None
|
||||
- random order
|
||||
* - None
|
||||
- True
|
||||
- random order
|
||||
* - None
|
||||
- False
|
||||
- sequential order
|
||||
* - Sampler object
|
||||
- None
|
||||
- order defined by sampler
|
||||
* - Sampler object
|
||||
- True
|
||||
- not allowed
|
||||
* - Sampler object
|
||||
- False
|
||||
- not allowed
|
||||
|
||||
Examples:
|
||||
>>> sb_dataset_dir = "/path/to/sb_dataset_directory"
|
||||
>>>
|
||||
>>> # 1) Get all samples from Semantic Boundaries Dataset in sequence
|
||||
>>> dataset = ds.SBDataset(dataset_dir=sb_dataset_dir, shuffle=False)
|
||||
>>>
|
||||
>>> # 2) Randomly select 350 samples from Semantic Boundaries Dataset
|
||||
>>> dataset = ds.SBDataset(dataset_dir=sb_dataset_dir, num_samples=350, shuffle=True)
|
||||
>>>
|
||||
>>> # 3) Get samples from Semantic Boundaries Dataset for shard 0 in a 2-way distributed training
|
||||
>>> dataset = ds.SBDataset(dataset_dir=sb_dataset_dir, num_shards=2, shard_id=0)
|
||||
>>>
|
||||
>>> # In Semantic Boundaries Dataset, each dictionary has keys "image" and "task"
|
||||
|
||||
About Semantic Boundaries Dataset.
|
||||
| The Semantic Boundaries Dataset consists of 11355 colour images. There are 8498 images' name in the train.txt,
|
||||
2857 images' name in the val.txt and 5623 images' name in the train_noval.txt. The category cls/
|
||||
contains the Segmentation and Boundaries results of category-level, the category inst/ catains the
|
||||
Segmentation and Boundaries results of instance-level.
|
||||
|
||||
| You can unzip the dataset files into the following structure and read by MindSpore's API,
|
||||
| .
|
||||
| └── benchmark_RELEASE
|
||||
| └── dataset
|
||||
| ├── img
|
||||
| | ├── 2008_000002.jpg
|
||||
| | ├── 2008_000003.jpg
|
||||
| | ├── ...
|
||||
| ├── cls
|
||||
| | ├── 2008_000002.mat
|
||||
| | ├── 2008_000003.mat
|
||||
| | ├── ...
|
||||
| ├── inst
|
||||
| | ├── 2008_000002.mat
|
||||
| | ├── 2008_000003.mat
|
||||
| | ├── ...
|
||||
| ├── train.txt
|
||||
| └── val.txt
|
||||
|
||||
.. code-block::
|
||||
|
||||
@InProceedings{BharathICCV2011,
|
||||
author = "Bharath Hariharan and Pablo Arbelaez and Lubomir Bourdev and
|
||||
Subhransu Maji and Jitendra Malik",
|
||||
title = "Semantic Contours from Inverse Detectors",
|
||||
booktitle = "International Conference on Computer Vision (ICCV)",
|
||||
year = "2011",
|
||||
"""
|
||||
|
||||
@check_sb_dataset
|
||||
def __init__(self, dataset_dir, task='Boundaries', usage='all', num_samples=None, num_parallel_workers=1,
|
||||
shuffle=None, decode=None, sampler=None, num_shards=None, shard_id=None):
|
||||
dataset = _SBDataset(dataset_dir, task, usage, decode)
|
||||
super().__init__(dataset, column_names=dataset.column_list, num_samples=num_samples,
|
||||
num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler,
|
||||
num_shards=num_shards, shard_id=shard_id)
|
||||
|
||||
|
||||
class _SBDataset():
|
||||
"""
|
||||
Dealing with the data file with .mat extension, and return one row in tuple (image, task) each time.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset_dir, task, usage, decode):
|
||||
self.column_list = ['image', 'task']
|
||||
self.task = task
|
||||
self.images_path = os.path.join(dataset_dir, 'img')
|
||||
self.cls_path = os.path.join(dataset_dir, 'cls')
|
||||
self._loadmat = loadmat
|
||||
self.categories = 20
|
||||
self.decode = replace_none(decode, False)
|
||||
|
||||
if usage == "all":
|
||||
image_names = []
|
||||
for item in ["train", "val"]:
|
||||
usage_path = os.path.join(dataset_dir, item + '.txt')
|
||||
if not os.path.exists(usage_path):
|
||||
raise FileNotFoundError("SBDataset: {0} not found".format(usage_path))
|
||||
with open(usage_path, 'r') as f:
|
||||
image_names += [x.strip() for x in f.readlines()]
|
||||
else:
|
||||
usage_path = os.path.join(dataset_dir, usage + '.txt')
|
||||
if not os.path.exists(usage_path):
|
||||
raise FileNotFoundError("SBDataset: {0} not found".format(usage_path))
|
||||
with open(usage_path, 'r') as f:
|
||||
image_names = [x.strip() for x in f.readlines()]
|
||||
|
||||
self.images = [os.path.join(self.images_path, i + ".jpg") for i in image_names]
|
||||
self.clss = [os.path.join(self.cls_path, i + ".mat") for i in image_names]
|
||||
|
||||
if len(self.images) != len(self.clss):
|
||||
raise ValueError("SBDataset: images count not equal to cls count")
|
||||
|
||||
self._get_data = self._get_boundaries_data if self.task == "Boundaries" else self._get_segmentation_data
|
||||
self._get_item = self._get_decode_item if self.decode else self._get_undecode_item
|
||||
|
||||
def _get_boundaries_data(self, mat_path):
|
||||
mat_data = self._loadmat(mat_path)
|
||||
return np.concatenate([np.expand_dims(mat_data['GTcls'][0][self.task][0][i][0].toarray(), axis=0)
|
||||
for i in range(self.categories)], axis=0)
|
||||
|
||||
def _get_segmentation_data(self, mat_path):
|
||||
mat_data = self._loadmat(mat_path)
|
||||
return Image.fromarray(mat_data['GTcls'][0][self.task][0])
|
||||
|
||||
def _get_decode_item(self, idx):
|
||||
return Image.open(self.images[idx]).convert('RGB'), self._get_data(self.clss[idx])
|
||||
|
||||
def _get_undecode_item(self, idx):
|
||||
return np.fromfile(self.images[idx], dtype=np.uint8), self._get_data(self.clss[idx])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.images)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self._get_item(idx)
|
||||
|
|
|
@ -1329,3 +1329,34 @@ def check_to_device_send(method):
|
|||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_sb_dataset(method):
|
||||
"""A wrapper that wraps a parameter checker around the original Semantic Boundaries Dataset."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
_, param_dict = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
|
||||
nreq_param_bool = ['shuffle', 'decode']
|
||||
|
||||
dataset_dir = param_dict.get('dataset_dir')
|
||||
check_dir(dataset_dir)
|
||||
|
||||
usage = param_dict.get('usage')
|
||||
if usage is not None:
|
||||
check_valid_str(usage, ["train", "val", "train_noval", "all"], "usage")
|
||||
|
||||
task = param_dict.get('task')
|
||||
if task is not None:
|
||||
check_valid_str(task, ["Boundaries", "Segmentation"], "task")
|
||||
|
||||
validate_dataset_param_value(nreq_param_int, param_dict, int)
|
||||
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
|
||||
|
||||
check_sampler_shuffle_shard_options(param_dict)
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
After Width: | Height: | Size: 17 KiB |
Binary file not shown.
After Width: | Height: | Size: 26 KiB |
Binary file not shown.
After Width: | Height: | Size: 22 KiB |
Binary file not shown.
After Width: | Height: | Size: 27 KiB |
Binary file not shown.
After Width: | Height: | Size: 31 KiB |
Binary file not shown.
After Width: | Height: | Size: 17 KiB |
|
@ -0,0 +1,4 @@
|
|||
000001
|
||||
000002
|
||||
000003
|
||||
000004
|
|
@ -0,0 +1,4 @@
|
|||
000001
|
||||
000003
|
||||
000005
|
||||
000006
|
|
@ -0,0 +1,2 @@
|
|||
000005
|
||||
000006
|
|
@ -0,0 +1,219 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import math
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import log as logger
|
||||
import mindspore.dataset.vision.c_transforms as c_vision
|
||||
|
||||
|
||||
DATASET_DIR = "../data/dataset/testSBData/sbd"
|
||||
|
||||
|
||||
def visualize_dataset(images, labels, task):
|
||||
"""
|
||||
Helper function to visualize the dataset samples
|
||||
"""
|
||||
image_num = len(images)
|
||||
subplot_rows = 1 if task == "Segmentation" else 4
|
||||
for i in range(image_num):
|
||||
plt.imshow(images[i])
|
||||
plt.title('Original')
|
||||
plt.savefig('./sbd_original_{}.jpg'.format(str(i)))
|
||||
if task == "Segmentation":
|
||||
plt.imshow(labels[i])
|
||||
plt.title(task)
|
||||
plt.savefig('./sbd_segmentation_{}.jpg'.format(str(i)))
|
||||
else:
|
||||
b_num = labels[i].shape[0]
|
||||
for j in range(b_num):
|
||||
plt.subplot(subplot_rows, math.ceil(b_num / subplot_rows), j + 1)
|
||||
plt.imshow(labels[i][j])
|
||||
plt.savefig('./sbd_boundaries_{}.jpg'.format(str(i)))
|
||||
plt.close()
|
||||
|
||||
|
||||
def test_sbd_basic01(plot=False):
|
||||
"""
|
||||
Validate SBDataset with different usage
|
||||
"""
|
||||
task = 'Segmentation' # Boundaries, Segmentation
|
||||
data = ds.SBDataset(DATASET_DIR, task=task, usage='all', shuffle=False, decode=True)
|
||||
count = 0
|
||||
images_list = []
|
||||
task_list = []
|
||||
for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
images_list.append(item['image'])
|
||||
task_list.append(item['task'])
|
||||
count = count + 1
|
||||
assert count == 6
|
||||
if plot:
|
||||
visualize_dataset(images_list, task_list, task)
|
||||
|
||||
data2 = ds.SBDataset(DATASET_DIR, task=task, usage='train', shuffle=False, decode=False)
|
||||
count = 0
|
||||
for item in data2.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
count = count + 1
|
||||
assert count == 4
|
||||
|
||||
data3 = ds.SBDataset(DATASET_DIR, task=task, usage='val', shuffle=False, decode=False)
|
||||
count = 0
|
||||
for item in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
count = count + 1
|
||||
assert count == 2
|
||||
|
||||
|
||||
def test_sbd_basic02():
|
||||
"""
|
||||
Validate SBDataset with repeat and batch operation
|
||||
"""
|
||||
# Boundaries, Segmentation
|
||||
# case 1: test num_samples
|
||||
data1 = ds.SBDataset(DATASET_DIR, task='Boundaries', usage='train', num_samples=3, shuffle=False)
|
||||
num_iter1 = 0
|
||||
for _ in data1.create_dict_iterator(num_epochs=1):
|
||||
num_iter1 += 1
|
||||
assert num_iter1 == 3
|
||||
|
||||
# case 2: test repeat
|
||||
data2 = ds.SBDataset(DATASET_DIR, task='Boundaries', usage='train', num_samples=4, shuffle=False)
|
||||
data2 = data2.repeat(5)
|
||||
num_iter2 = 0
|
||||
for _ in data2.create_dict_iterator(num_epochs=1):
|
||||
num_iter2 += 1
|
||||
assert num_iter2 == 20
|
||||
|
||||
# case 3: test batch with drop_remainder=False
|
||||
data3 = ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', shuffle=False, decode=True)
|
||||
resize_op = c_vision.Resize((100, 100))
|
||||
data3 = data3.map(operations=resize_op, input_columns=["image"], num_parallel_workers=1)
|
||||
data3 = data3.map(operations=resize_op, input_columns=["task"], num_parallel_workers=1)
|
||||
assert data3.get_dataset_size() == 4
|
||||
assert data3.get_batch_size() == 1
|
||||
data3 = data3.batch(batch_size=3) # drop_remainder is default to be False
|
||||
assert data3.get_dataset_size() == 2
|
||||
assert data3.get_batch_size() == 3
|
||||
num_iter3 = 0
|
||||
for _ in data3.create_dict_iterator(num_epochs=1):
|
||||
num_iter3 += 1
|
||||
assert num_iter3 == 2
|
||||
|
||||
# case 4: test batch with drop_remainder=True
|
||||
data4 = ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', shuffle=False, decode=True)
|
||||
resize_op = c_vision.Resize((100, 100))
|
||||
data4 = data4.map(operations=resize_op, input_columns=["image"], num_parallel_workers=1)
|
||||
data4 = data4.map(operations=resize_op, input_columns=["task"], num_parallel_workers=1)
|
||||
assert data4.get_dataset_size() == 4
|
||||
assert data4.get_batch_size() == 1
|
||||
data4 = data4.batch(batch_size=3, drop_remainder=True) # the rest of incomplete batch will be dropped
|
||||
assert data4.get_dataset_size() == 1
|
||||
assert data4.get_batch_size() == 3
|
||||
num_iter4 = 0
|
||||
for _ in data4.create_dict_iterator(num_epochs=1):
|
||||
num_iter4 += 1
|
||||
assert num_iter4 == 1
|
||||
|
||||
|
||||
def test_sbd_sequential_sampler():
|
||||
"""
|
||||
Test SBDataset with SequentialSampler
|
||||
"""
|
||||
logger.info("Test SBDataset Op with SequentialSampler")
|
||||
num_samples = 5
|
||||
sampler = ds.SequentialSampler(num_samples=num_samples)
|
||||
data1 = ds.SBDataset(DATASET_DIR, task='Segmentation', usage='all', sampler=sampler)
|
||||
data2 = ds.SBDataset(DATASET_DIR, task='Segmentation', usage='all', shuffle=False, num_samples=num_samples)
|
||||
num_iter = 0
|
||||
for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
|
||||
data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
|
||||
np.testing.assert_array_equal(item1["task"], item2["task"])
|
||||
num_iter += 1
|
||||
assert num_iter == num_samples
|
||||
|
||||
|
||||
def test_sbd_exception():
|
||||
"""
|
||||
Validate SBDataset with error parameters
|
||||
"""
|
||||
error_msg_1 = "sampler and shuffle cannot be specified at the same time"
|
||||
with pytest.raises(RuntimeError, match=error_msg_1):
|
||||
ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', shuffle=False, sampler=ds.PKSampler(3))
|
||||
|
||||
error_msg_2 = "sampler and sharding cannot be specified at the same time"
|
||||
with pytest.raises(RuntimeError, match=error_msg_2):
|
||||
ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', num_shards=2, shard_id=0,
|
||||
sampler=ds.PKSampler(3))
|
||||
|
||||
error_msg_3 = "num_shards is specified and currently requires shard_id as well"
|
||||
with pytest.raises(RuntimeError, match=error_msg_3):
|
||||
ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', num_shards=10)
|
||||
|
||||
error_msg_4 = "shard_id is specified but num_shards is not"
|
||||
with pytest.raises(RuntimeError, match=error_msg_4):
|
||||
ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', shard_id=0)
|
||||
|
||||
error_msg_5 = "Input shard_id is not within the required interval"
|
||||
with pytest.raises(ValueError, match=error_msg_5):
|
||||
ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', num_shards=5, shard_id=-1)
|
||||
with pytest.raises(ValueError, match=error_msg_5):
|
||||
ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', num_shards=5, shard_id=5)
|
||||
with pytest.raises(ValueError, match=error_msg_5):
|
||||
ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', num_shards=2, shard_id=5)
|
||||
|
||||
error_msg_6 = "num_parallel_workers exceeds"
|
||||
with pytest.raises(ValueError, match=error_msg_6):
|
||||
ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', shuffle=False, num_parallel_workers=0)
|
||||
with pytest.raises(ValueError, match=error_msg_6):
|
||||
ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', shuffle=False, num_parallel_workers=256)
|
||||
with pytest.raises(ValueError, match=error_msg_6):
|
||||
ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', shuffle=False, num_parallel_workers=-2)
|
||||
|
||||
error_msg_7 = "Argument shard_id"
|
||||
with pytest.raises(TypeError, match=error_msg_7):
|
||||
ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', num_shards=2, shard_id="0")
|
||||
|
||||
|
||||
def test_sbd_usage():
|
||||
"""
|
||||
Validate SBDataset image readings
|
||||
"""
|
||||
def test_config(usage):
|
||||
try:
|
||||
data = ds.SBDataset(DATASET_DIR, task='Segmentation', usage=usage)
|
||||
num_rows = 0
|
||||
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
num_rows += 1
|
||||
except (ValueError, TypeError, RuntimeError) as e:
|
||||
return str(e)
|
||||
return num_rows
|
||||
|
||||
assert test_config("train") == 4
|
||||
assert test_config("train_noval") == 4
|
||||
assert test_config("val") == 2
|
||||
assert test_config("all") == 6
|
||||
assert "usage is not within the valid set of ['train', 'val', 'train_noval', 'all']" in test_config("invalid")
|
||||
assert "Argument usage with value ['list'] is not of type [<class 'str'>]" in test_config(["list"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_sbd_basic01()
|
||||
test_sbd_basic02()
|
||||
test_sbd_sequential_sampler()
|
||||
test_sbd_exception()
|
||||
test_sbd_usage()
|
Loading…
Reference in New Issue