forked from mindspore-Ecosystem/mindspore
!30810 [MD] add OBSMindDataset
Merge pull request !30810 from liyong126/add_obs_mindrecord_dataset
commit 9fcab9184e
@@ -30,7 +30,10 @@ import mindspore._c_dataengine as cde
from mindspore import log as logger
from .datasets import UnionBaseDataset, SourceDataset, MappableDataset, Shuffle, Schema, \
    shuffle_to_shuffle_mode, shuffle_to_bool
from .validators import check_minddataset, check_tfrecorddataset, check_csvdataset
from .datasets_user_defined import GeneratorDataset
from .obs.obs_mindrecord_dataset import MindRecordFromOBS
from .validators import check_csvdataset, check_minddataset, check_tfrecorddataset, check_obsminddataset

from ..core.validator_helpers import replace_none
from . import samplers
@@ -324,3 +327,88 @@ class TFRecordDataset(SourceDataset, UnionBaseDataset):
        schema = self.schema.cpp_schema if isinstance(self.schema, Schema) else self.schema
        return cde.TFRecordNode(self.dataset_files, schema, self.columns_list, self.num_samples, self.shuffle_flag,
                                self.num_shards, self.shard_id, self.shard_equal_rows)


class OBSMindDataset(GeneratorDataset):
    """
    A source dataset that reads and parses MindRecord dataset files stored in OBS.

    The columns of the generated dataset depend on the source MindRecord files.

    Args:
        dataset_files (list[str]): List of files in OBS to be read, where each file path
            is in the format of s3://.
        server (str): Endpoint for accessing OBS. For example: <https://your-endpoint:9000>.
        ak (str): Access key ID of OBS.
        sk (str): Secret access key of OBS.
        sync_obs_path (str): OBS directory used for synchronization; users need to
            create it on OBS in advance. The path is in the format of s3://.
        column_list (list[str], optional): List of columns to be read (default=None, read all columns).
        shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch
            (default=Shuffle.GLOBAL).
            If shuffle is False, no shuffling will be performed.
            If shuffle is True, the behavior is the same as setting shuffle to Shuffle.GLOBAL.
            Otherwise, there are two levels of shuffling:

            - Shuffle.GLOBAL: Shuffle both the files and samples.

            - Shuffle.FILES: Shuffle files only.

        num_shards (int, optional): Number of shards that the dataset will be divided
            into (default=None).
        shard_id (int, optional): The shard ID within `num_shards` (default=None). This
            argument can only be specified when `num_shards` is also specified.
        shard_equal_rows (bool, optional): Get equal rows for all shards (default=True). If shard_equal_rows
            is False, the number of rows in each shard may be unequal, which may lead to a failure in
            distributed training. It is suggested to set it to True when the number of samples per
            MindRecord file is not equal. This argument should only be specified when `num_shards` is
            also specified.

    Raises:
        RuntimeError: If `sync_obs_path` does not exist.
        ValueError: If `column_list` is invalid.
        RuntimeError: If `num_shards` is specified but `shard_id` is None.
        RuntimeError: If `shard_id` is specified but `num_shards` is None.
        ValueError: If `shard_id` is invalid (< 0 or >= `num_shards`).

    Note:
        - It is necessary to create a synchronization directory on OBS in
          advance, which is specified by the parameter `sync_obs_path` .
        - If training is offline (no cloud), it is recommended to set the
          environment variable `BATCH_JOB_ID`.
        - In distributed training, if there are multiple nodes (servers), all 8
          devices must be used on each node (server). If there is only one
          node (server), there is no such restriction.

    Examples:
        >>> dataset_obs_dir = ["s3://path/to/obs_dataset_file_1", "s3://path/to/obs_dataset_file_2"]
        >>> sync_obs_dir = "s3://sync-dir"
        >>> dataset = ds.OBSMindDataset(dataset_obs_dir, "https://your-endpoint:9000", "AK of OBS", "SK of OBS",
        ...                             sync_obs_dir, shuffle=True, num_shards=num_shards, shard_id=shard_id)

    """

    @check_obsminddataset
    def __init__(self, dataset_files, server, ak, sk, sync_obs_path,
                 column_list=None,
                 shuffle=Shuffle.GLOBAL,
                 num_shards=None,
                 shard_id=None,
                 shard_equal_rows=True):

        from .obs.config_loader import config
        config.AK = ak
        config.SK = sk
        config.SERVER = server
        config.SYNC_OBS_PATH = sync_obs_path
        dataset = MindRecordFromOBS(dataset_files, column_list, shuffle, num_shards, shard_id,
                                    shard_equal_rows, config.DATASET_LOCAL_PATH)
        if not column_list:
            column_list = dataset.get_col_names()
        else:
            full_column_list = dataset.get_col_names()
            if not set(column_list).issubset(full_column_list):
                raise ValueError("column_list: {} can not be found in MindRecord fields: {}".format(column_list,
                                                                                                    full_column_list))
        super().__init__(source=dataset, column_names=column_list, num_shards=None, shard_id=None, shuffle=False)
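For orientation, here is a minimal usage sketch of the new class, assuming `OBSMindDataset` is exported under `mindspore.dataset` as the docstring example above suggests; the endpoint, bucket paths and credentials are placeholders.

import mindspore.dataset as ds

# Placeholder OBS locations and credentials -- replace with real values.
dataset_files = ["s3://my-bucket/train/part0.mindrecord", "s3://my-bucket/train/part1.mindrecord"]
dataset = ds.OBSMindDataset(dataset_files, "https://obs.example.com:443", "my-ak", "my-sk",
                            "s3://my-bucket/sync", shuffle=True, num_shards=8, shard_id=0)
dataset = dataset.batch(32)

# The result behaves like any other source dataset: map, batch and iterate as usual.
for row in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
    pass  # row is a dict keyed by the MindRecord column names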
@@ -0,0 +1,23 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Introduction to dataset/engine/obs:

dataset/engine/obs provides the implementation of OBSMindDataset.
"""

from .obs_mindrecord_dataset import MindRecordFromOBS, sync_wait_for_dataset

__all__ = ["MindRecordFromOBS", "sync_wait_for_dataset"]
@@ -0,0 +1,59 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This dataset module provides internal global variables for the OBSMindDataset API.
"""
import os


class Config:
    """ Config class for OBSMindDataset. """

    WORKING_PATH = "/cache"
    DATASET_LOCAL_PATH = os.path.join(WORKING_PATH, "dataset")
    DISK_THRESHOLD = 0.75
    TASK_NUM = 8
    PART_SIZE = 10*1024*1024
    MAX_RETRY = 3
    RETRY_DELTA_TIME = 10

    WARMINGUP_TIME = 10   # warmup time
    WAIT_STEP_TIME = 0.1  # wait time when cache miss
    WAIT_META_TIME = 1
    SEED = 1234


class _Config:
    """ Internal class that gets and sets global variables. """

    def __init__(self):
        self.config = dict((k, v) for k, v in Config.__dict__.items(
        ) if not callable(v) and not k.startswith('__'))

    def __getattr__(self, key):
        if key in os.environ:
            return os.environ[key]
        if key in self.config:
            return self.config[key]
        return None

    def __setattr__(self, key, value):
        if key == 'config':
            self.__dict__[key] = value
        else:
            self.config[key] = value


config = _Config()
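Because `_Config.__getattr__` consults `os.environ` before the `Config` defaults, each of these knobs can be overridden through an environment variable; values taken from the environment come back as strings. A small sketch of the lookup order, assuming the module path shown in this diff:

import os
from mindspore.dataset.engine.obs.config_loader import config

print(config.TASK_NUM)           # 8, the Config class default
os.environ['TASK_NUM'] = '16'    # the environment takes precedence on the next lookup ...
print(config.TASK_NUM)           # '16' -- note it is returned as a string
config.AK = 'my-access-key'      # assignments are stored in the internal dict and read back via config.AK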
@@ -0,0 +1,495 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""
This dataset module provides the internal dataset API that loads MindRecord files from OBS.
"""


import math
from multiprocessing.dummy import Pool as ThreadPool
from multiprocessing.managers import SyncManager
import os
import queue
import random
import sys
import time

from mindspore import log as logger
from ..datasets import Shuffle
from ...core.config import set_seed


class Manager(SyncManager):
    pass


def get_manager():
    """ Return a manager whose PriorityQueue can be shared across threads. """

    Manager.register("PriorityQueue", queue.PriorityQueue)
    m = Manager()
    m.start()
    return m
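# A brief illustration of what get_manager() provides (file names below are illustrative):
# the registered PriorityQueue lives in the manager's server process, so the download thread
# and the iterator can share it safely, and get() always pops the lowest index first.
#
#     >>> m = get_manager()
#     >>> pq = m.PriorityQueue()
#     >>> pq.put((2, "part2.mindrecord"))
#     >>> pq.put((0, "part0.mindrecord"))
#     >>> pq.get()
#     (0, 'part0.mindrecord')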
def init_cache_and_working_queue(cache, q, shard_files, local_path):
    """
    Initialize the download queue and the local cache which stores the status of local dataset files.
    """

    from .util import init_cache_and_queue

    idx = 0
    for shard_file, _, _, is_full_dataset in shard_files:
        dataset_file = os.path.basename(shard_file)
        path = os.path.join(local_path, dataset_file)
        init_cache_and_queue(cache, q, path, shard_file,
                             idx, is_full_dataset, lock_file=dataset_file)
        idx += 1
    return cache, q


def remove_unused_dataset(local_path, num_shards, shard_id, epoch_num):
    """ Ranks whose rank_id mod 8 equals 0 remove all dataset files. """

    from .config_loader import config

    if not num_shards:
        return
    # If num_shards is less than or equal to 8, assume that there is only one node (server) and
    # the dataset does not need to be removed.
    if num_shards <= 8 or shard_id % 8 != 0:
        return

    sync_dir = '/cache/sync_data/' + str(epoch_num)
    while True:
        if os.path.exists(sync_dir) and len(os.listdir(sync_dir)) >= min(num_shards - 1, 7):
            break
        time.sleep(config.WARMINGUP_TIME)
        logger.info("[{} FUNCTION] Shard: {} wait for other rank ready in epoch: {}.".format(
            sys._getframe().f_code.co_name, shard_id, epoch_num))  # pylint: disable=W0212

    files = os.listdir(local_path)
    for dataset_file in files:
        if dataset_file.endswith('.db'):
            continue
        dataset_path = os.path.join(local_path, dataset_file)
        os.remove(dataset_path)

    for ready_file in os.listdir(sync_dir):
        os.remove(os.path.join(sync_dir, ready_file))


def wait_remove_datset(num_shards, shard_id, epoch_num):
    """ Ranks whose rank_id mod 8 is not 0 wait for the dataset files to be removed. """

    from .config_loader import config

    if not num_shards:
        return
    if num_shards <= 8 or shard_id % 8 == 0:
        return

    sync_dir = '/cache/sync_data/' + str(epoch_num)

    if not os.path.exists(sync_dir):
        try:
            os.makedirs(sync_dir)
        except FileExistsError:
            pass

    sync_file = os.path.join(sync_dir, 'ready_' + str(shard_id))
    with open(sync_file, 'w') as f:
        f.write('ok')

    while True:
        if os.path.exists(sync_dir) and not os.listdir(sync_dir):
            break
        time.sleep(config.WARMINGUP_TIME)
        logger.info("[{} FUNCTION] Shard: {} wait for removing dataset files in epoch: {}.".format(
            sys._getframe().f_code.co_name, shard_id, epoch_num))  # pylint: disable=W0212
def init_shard_files(dataset_files, shuffle, seed, num_shards, shard_id, shard_equal_rows,
                     size_per_shard, local_path, current_epoch):
    """ Calculate the dataset files required by each shard and the corresponding indexes. """

    from .config_loader import config
    from .util import detect_all_meta_files, fetch_meta_files, make_dataset_tuple, make_shard_files, make_shard_samples

    shard_files = None
    if shuffle is False:
        pass
    else:
        set_seed(seed)
        random.shuffle(dataset_files)
    if num_shards:  # distributed training
        # As each shard has the same number of samples, all meta files need to be fetched.
        if shard_equal_rows:
            if size_per_shard is None:
                if shard_id % 8 == 0:
                    fetch_meta_files(dataset_files, local_path, shard_id)
                else:
                    while detect_all_meta_files(dataset_files, local_path) is False:
                        time.sleep(config.WAIT_META_TIME)
            full_dataset_size, dataset_file_size_list = make_dataset_tuple(
                dataset_files, local_path)
            size_per_shard = math.ceil(full_dataset_size / num_shards)
            shard_files = make_shard_samples(
                dataset_file_size_list, size_per_shard, shard_id)
        else:
            shard_files = make_shard_files(dataset_files, num_shards, shard_id)
    else:
        shard_files = [(dataset_file, -1, -1, True)
                       for dataset_file in dataset_files]
    logger.info("[{} FUNCTION] Shard: {} expect dataset: {} in epoch: {}.".format(
        sys._getframe().f_code.co_name, shard_id, shard_files, current_epoch))  # pylint: disable=W0212
    return shard_files, size_per_shard
def download_work(shard_id, current_idx, local_path, cache, q):
    """ Daemon worker that runs in the background download thread. """
    from .config_loader import config
    from .util import try_load_from_obs, get_used_disk_per

    while True:
        idx, dataset_file = q.get()
        used_disk = get_used_disk_per()
        while used_disk > config.DISK_THRESHOLD:
            logger.info("[{} FUNCTION] Used disk space is {}%, and the disk threshold is {}%.".format(
                sys._getframe().f_code.co_name, used_disk*100, config.DISK_THRESHOLD*100))  # pylint: disable=W0212
            retry_cnt = 0
            has_deleted = delete_candidate_datasets(
                current_idx.value, idx, cache, q, local_path)
            while not has_deleted:
                if retry_cnt > config.MAX_RETRY:
                    logger.warning("Delete operation retries times {} has exceeded threshold {}, "
                                   "please clear enough disk space.".format(retry_cnt, config.MAX_RETRY))
                has_deleted = delete_candidate_datasets(
                    current_idx.value, idx, cache, q, local_path)
                retry_cnt += 1
                time.sleep(config.RETRY_DELTA_TIME)
            used_disk = get_used_disk_per()

        logger.info("[{} FUNCTION] Shard: {} try to download: {}.".format(
            sys._getframe().f_code.co_name, shard_id, dataset_file))  # pylint: disable=W0212
        # update cache
        remote_path = os.path.dirname(dataset_file)
        dataset_file = os.path.basename(dataset_file)
        _, is_shared = cache[dataset_file]
        try_load_from_obs(remote_path, dataset_file, local_path, shard_id)
        cache[dataset_file] = (idx, is_shared)
        logger.info("[{} FUNCTION] Shard: {} finish to download: {}.".format(
            sys._getframe().f_code.co_name, shard_id, dataset_file))  # pylint: disable=W0212


def delete_candidate_datasets(current_idx, queue_top_idx, cache, q, local_path):
    """
    1. Try to delete all the dataset files which have already been consumed during the epoch.
    2. Otherwise, try to delete a low priority dataset file of the epoch.
    3. As soon as a low priority dataset file is deleted, it is placed back into the download queue.
    """

    used_datasets = []
    low_priority_dataset = ''
    max_idx = -1
    delete = False
    for k, v in cache.items():
        idx, is_shared = v
        if is_shared is False and idx >= 0:
            if idx > max_idx:
                max_idx = idx
                low_priority_dataset = k
            if idx < current_idx:
                used_datasets.append(k)
    for used_dataset in used_datasets:
        dataset_path = os.path.join(local_path, used_dataset)
        if not os.path.exists(dataset_path):
            continue
        # update cache
        idx, is_shared = cache[used_dataset]
        cache[used_dataset] = (-1, is_shared)
        os.remove(dataset_path)
        delete = True
        logger.info("[{} FUNCTION] Delete used dataset file: {} and update the cache.".format(
            sys._getframe().f_code.co_name, used_dataset))  # pylint: disable=W0212

    if delete:
        return True
    if max_idx <= current_idx or max_idx <= queue_top_idx:
        return False
    dataset_path = os.path.join(local_path, low_priority_dataset)
    if not os.path.exists(dataset_path):
        return False
    # update cache
    idx, is_shared = cache[low_priority_dataset]
    cache[low_priority_dataset] = (-1, is_shared)
    os.remove(dataset_path)
    q.put((idx, low_priority_dataset))
    logger.info("[{} FUNCTION] Delete low priority dataset file: {} and update the cache.".format(
        sys._getframe().f_code.co_name, low_priority_dataset))  # pylint: disable=W0212
    return True
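# An illustrative eviction scenario (values assumed for the example): with current_idx = 3,
# queue_top_idx = 6 and a cache such as
#     {"part0.mindrecord": (1, False), "part4.mindrecord": (4, False), "part7.mindrecord": (7, False)},
# "part0.mindrecord" (index 1 < 3) is removed first because it was already consumed this epoch.
# Only when no consumed file is left does the highest-index entry ("part7.mindrecord") get evicted
# and pushed back onto the download queue; entries whose second field is True (shared between
# shards) are never considered for eviction.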
def _sync_up_for_obs_mindrecord_dataset(rank_id, current_epoch):
    """ Upload the synchronization file to OBS. """

    from .config_loader import config
    from .util import file_upload_to_obs

    sync_info = "download_dataset"
    job_id = os.environ.get('BATCH_JOB_ID', 'unknown')
    ready_file_name = sync_info + '_ready_' + str(rank_id) + '.txt'
    ready_dir = os.path.join(job_id, str(current_epoch) + "/")

    file_upload_to_obs(config.SYNC_OBS_PATH, ready_dir, ready_file_name)
    logger.info("[{} FUNCTION] Current rank:{}'s sync file:{} is ready for epoch:{}.".format(
        sys._getframe().f_code.co_name, rank_id, os.path.join(ready_dir, ready_file_name), current_epoch))  # pylint: disable=W0212


def sync_wait_for_dataset(rank_id, rank_size, current_epoch):
    """
    Wait until the dataset files required by all devices are downloaded.

    Note:
        It should be used together with `mindspore.dataset.OBSMindDataset` and
        be called before each epoch.

    Args:
        rank_id(int): Rank ID of the device.
        rank_size(int): Rank size.
        current_epoch(int): Number of the current epoch.

    Examples:
        >>> # Create a synchronization callback
        >>>
        >>> from mindspore.dataset import sync_wait_for_dataset
        >>> from mindspore.train.callback import Callback
        >>>
        >>> class SyncForDataset(Callback):
        ...     def __init__(self):
        ...         super(SyncForDataset, self).__init__()
        ...     def epoch_begin(self, run_context):
        ...         cb_params = run_context.original_args()
        ...         epoch_num = cb_params.cur_epoch_num
        ...         sync_wait_for_dataset(rank_id, rank_size, epoch_num)

    """

    from .config_loader import config
    from .util import obsClient, get_bucket_and_key

    bucket_name, object_key = get_bucket_and_key(config.SYNC_OBS_PATH)

    job_id = os.environ.get('BATCH_JOB_ID', 'unknown')
    ready_dir = os.path.join(object_key, job_id, str(current_epoch) + "/")

    success = False
    while True:
        if success:
            break
        try:
            # There is no guarantee that the directory itself is included in the listing.
            resp = obsClient.listObjects(bucket_name, prefix=ready_dir)
            if resp.status < 300:
                ready_num = 0
                for content in resp.body.contents:
                    if content.key.endswith(".txt"):
                        ready_num += 1
                if ready_num >= rank_size:
                    success = True
            else:
                logger.warning("[{} FUNCTION] OBS SDK errorCode:{}, errMsg: {}.".format(
                    sys._getframe(), resp.errorCode, resp.errorMessage))  # pylint: disable=W0212
        except Exception:  # pylint: disable=W0703
            import traceback
            logger.error(traceback.format_exc())
        time.sleep(config.RETRY_DELTA_TIME)
        logger.info("[{} FUNCTION] Waiting for sync dir:{} and current_rank:{}, total_rank:{}, "
                    "ready_rank:{} in epoch:{}.".format(sys._getframe().f_code.co_name, ready_dir,  # pylint: disable=W0212
                                                        rank_id, rank_size, ready_num, current_epoch))
    logger.info("[{} FUNCTION] Succeed to sync dir:{} and begin epoch:{}.".format(
        sys._getframe().f_code.co_name, ready_dir, current_epoch))  # pylint: disable=W0212
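# Continuing the docstring example above (a sketch; `net`, `loss`, `opt`, `rank_id` and
# `rank_size` are assumed to be defined by the training script), the callback is registered
# so that every epoch waits until all ranks have downloaded their shard before stepping:
#
#     >>> from mindspore import Model
#     >>> model = Model(net, loss_fn=loss, optimizer=opt)
#     >>> model.train(epoch=10, train_dataset=dataset, callbacks=[SyncForDataset()])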
def _sync_for_obs_mindrecord_dataset(shard_files, cache, num_shards, shard_id, current_epoch):
    """ Synchronize all shards. """

    from .config_loader import config

    while True:
        dataset, _, _, _ = shard_files[-1]
        current_dataset = os.path.basename(dataset)
        hit_cache = cache[current_dataset][0]
        if hit_cache >= 0:  # hit cache
            logger.info("[{} FUNCTION] Current_rank:{} has downloaded:{} for epoch:{}.".format(
                sys._getframe().f_code.co_name, shard_id, dataset, current_epoch))  # pylint: disable=W0212
            _sync_up_for_obs_mindrecord_dataset(shard_id, current_epoch)
            break
        time.sleep(config.WARMINGUP_TIME)
        logger.info("[{} FUNCTION] Current_rank:{} wait for downloading:{} in epoch:{}.".format(
            sys._getframe().f_code.co_name, shard_id, dataset, current_epoch))  # pylint: disable=W0212
    sync_wait_for_dataset(shard_id, num_shards, current_epoch)
class MindRecordFromOBS:
    """ Internal class that loads remote MindRecord dataset files from OBS. """

    def __init__(self, dataset_files, columns_list, shuffle, num_shards, shard_id, shard_equal_rows, local_path):
        self._dataset_files = dataset_files
        self._columns_list = columns_list

        self._num_shards = num_shards
        self._shard_id = shard_id

        self._shard_equal_rows = shard_equal_rows
        self._local_path = os.path.realpath(local_path)

        self._shuffle = Shuffle.GLOBAL if shuffle is None or shuffle is True else shuffle
        from .config_loader import config
        self._seed = config.SEED
        self._size_per_shard = None
        self._curr_epoch = 1
        self._curr_step = 1
        self._shard_files, self._size_per_shard = init_shard_files(self._dataset_files, self._shuffle, self._seed,
                                                                   self._num_shards, self._shard_id,
                                                                   self._shard_equal_rows, self._size_per_shard,
                                                                   self._local_path, self._curr_epoch)

        m = get_manager()
        self._queue = m.PriorityQueue()
        self._cache = m.dict()

        self._index = 0
        self._current_idx = m.Value('i', self._index)

        self._cache, self._queue = init_cache_and_working_queue(
            self._cache, self._queue, self._shard_files, self._local_path)

        self._index = 0
        self._first_epoch = True
        self._iteration = None
        self._cache_miss_times = 0

        self._pool = ThreadPool(processes=1)
        self._pool.apply_async(download_work, (self._shard_id,
                                               self._current_idx, self._local_path, self._cache, self._queue))
        _sync_for_obs_mindrecord_dataset(
            self._shard_files, self._cache, self._num_shards, self._shard_id, self._curr_epoch)

    def get_col_names(self):
        """ Get the column names of the MindRecord dataset. """

        from ..datasets_standard_format import MindDataset

        target_dataset = None
        while target_dataset is None:
            for f, _, _, _ in self._shard_files:
                current_dataset = os.path.basename(f)
                if self._cache[current_dataset][0] >= 0:
                    target_dataset = current_dataset
        path = os.path.join(self._local_path, target_dataset)
        _iteration = MindDataset(dataset_files=[path], shuffle=False)
        return _iteration.get_col_names()

    def __next__(self):
        from .config_loader import config
        from ..datasets_standard_format import MindDataset
        from .util import make_sampler

        if self._iteration:
            try:
                self._curr_step += 1
                return next(self._iteration)
            except StopIteration:
                self._index += 1
                self._current_idx.value = self._index

                self._iteration = None
                if self._index >= len(self._shard_files):
                    self._first_epoch = False
                    self._curr_epoch += 1
                    self._curr_step = 0
                    raise StopIteration
                return next(self)
        else:
            f, start, end, is_full_dataset = self._shard_files[self._index]
            current_dataset = os.path.basename(f)
            hit_cache = self._cache[current_dataset][0]
            if hit_cache >= 0:  # hit cache
                self._cache_miss_times = 0
                # launch pipeline
                sampler = make_sampler(
                    self._shuffle, is_full_dataset, start, end)
                path = os.path.join(self._local_path, current_dataset)
                logger.info("[{} FUNCTION] Shard:{} start to load dataset:{} in epoch:{}.".format(
                    sys._getframe().f_code.co_name, self._shard_id, path, self._curr_epoch))  # pylint: disable=W0212
                self._iteration = MindDataset(dataset_files=[path], columns_list=self._columns_list, sampler=sampler,
                                              shuffle=None).create_tuple_iterator(num_epochs=1, output_numpy=True)

            else:
                # cache miss
                self._cache_miss_times += 1
                logger.info("[{} FUNCTION] Cache miss in shard {} for times {}, expect dataset {}.".format(
                    sys._getframe().f_code.co_name, self._shard_id, self._cache_miss_times, current_dataset))  # pylint: disable=W0212
                time.sleep(self._cache_miss_times * config.WAIT_STEP_TIME)
            return next(self)

    def __iter__(self):
        if self._first_epoch:
            self._index = 0
            self._current_idx.value = self._index
            self._iteration = None
            return self
        self._index = 0
        self._current_idx.value = self._index

        self._seed += 1
        self._iteration = None
        self._shard_files, self._size_per_shard = init_shard_files(self._dataset_files, self._shuffle, self._seed,
                                                                   self._num_shards, self._shard_id,
                                                                   self._shard_equal_rows, self._size_per_shard,
                                                                   self._local_path, self._curr_epoch)
        self._cache.clear()
        # reset queue
        try:
            while True:
                self._queue.get_nowait()
        except queue.Empty:
            pass

        remove_unused_dataset(
            self._local_path, self._num_shards, self._shard_id, self._curr_epoch)
        wait_remove_datset(self._num_shards, self._shard_id, self._curr_epoch)

        self._cache, self._queue = init_cache_and_working_queue(
            self._cache, self._queue, self._shard_files, self._local_path)
        _sync_for_obs_mindrecord_dataset(
            self._shard_files, self._cache, self._num_shards, self._shard_id, self._curr_epoch)
        return self

    def __len__(self):
        from .util import fetch_meta_files, make_dataset_tuple

        if self._size_per_shard is not None:
            return self._size_per_shard
        dataset_files = []
        for dataset_file, _, _, _ in self._shard_files:
            dataset_files.append(dataset_file)
        fetch_meta_files(dataset_files, self._local_path, self._shard_id)
        self._size_per_shard, _ = make_dataset_tuple(
            dataset_files, self._local_path)
        return len(self)
@@ -0,0 +1,377 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This module provides internal utility functions for the OBSMindDataset API.
"""

import fcntl
import os
import shutil
import sys
import sqlite3
import time

from functools import wraps
from obs import ObsClient

from mindspore import log as logger
from .config_loader import config
from ..datasets import Shuffle
from ..samplers import RandomSampler, SequentialSampler, SubsetSampler, SubsetRandomSampler

obsClient = ObsClient(
    access_key_id=config.AK,
    secret_access_key=config.SK,
    server=config.SERVER
)
def get_used_disk_per():
    """ Get the disk usage of the working directory. """

    if not os.path.exists(config.WORKING_PATH):
        try:
            os.makedirs(config.WORKING_PATH)
        except FileExistsError:
            pass

    total, used, _ = shutil.disk_usage(config.WORKING_PATH)
    return used / total


def try_load_from_obs(remote_path, dataset_file, local_path, shard_id):
    """ Download the dataset file and its meta file from OBS, skipping files that already exist. """
    try:
        if not os.path.exists(os.path.join(local_path, dataset_file)):
            download_file(remote_path, dataset_file, local_path, lock_file=dataset_file)
        meta_file = dataset_file + '.db'
        if not os.path.exists(os.path.join(local_path, meta_file)):
            download_file(remote_path, meta_file, local_path, lock_file=meta_file)
    except Exception as e:
        raise RuntimeError("Failed to fetch file from OBS, error: " + str(e))


def detect_all_meta_files(dataset_files, local_path):
    """ Check that all meta files exist locally. """

    all_meta_files = True
    for f in dataset_files:
        dataset_file = os.path.basename(f)
        meta_file = dataset_file + '.db'
        if detect_file_exist(local_path, meta_file, lock_file=meta_file) is False:
            all_meta_files = False
            break
    return all_meta_files


def make_sampler(shuffle, is_full_dataset, start, end):
    """ Generate a proper sampler based on the inputs. """

    sampler = None
    if shuffle == Shuffle.GLOBAL:
        if is_full_dataset:
            sampler = RandomSampler()
        else:
            sampler = SubsetRandomSampler(list(range(start, end)))
    else:
        if is_full_dataset:
            sampler = SequentialSampler()
        else:
            sampler = SubsetSampler(list(range(start, end)))
    return sampler
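# For reference, make_sampler maps the shuffle mode and the shard slice onto the standard
# samplers (arguments below are illustrative): Shuffle.GLOBAL with a full file yields
# RandomSampler(); Shuffle.GLOBAL with a partial file, e.g. start=0, end=100, yields
# SubsetRandomSampler(list(range(0, 100))); any other mode yields SequentialSampler() for a
# full file and SubsetSampler(list(range(start, end))) for a partial one.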
def make_shard_samples(dataset_file_size_list, size_per_shard, shard_id):
    """ Make the per-shard sample ranges when shard_equal_rows is True. """
    pre_cnt = 0
    shard_files = []
    finish = False
    while finish is False:
        for f, dataset_size in dataset_file_size_list:
            start_idx = shard_id * size_per_shard
            end_idx = (shard_id + 1) * size_per_shard
            push = False
            is_full_dataset = False

            if pre_cnt <= start_idx < pre_cnt + dataset_size:
                start = start_idx - pre_cnt
                push = True
                if pre_cnt < end_idx <= pre_cnt + dataset_size:
                    end = end_idx - pre_cnt
                else:
                    end = dataset_size

            if start_idx <= pre_cnt < end_idx:
                start = 0
                push = True
                if pre_cnt + dataset_size >= end_idx:
                    end = end_idx - pre_cnt
                else:
                    end = dataset_size

            if push:
                if start == 0 and end == dataset_size:
                    is_full_dataset = True
                shard_files.append((f, start, end, is_full_dataset))
            pre_cnt += dataset_size
            if pre_cnt >= (shard_id + 1) * size_per_shard:
                finish = True
    return shard_files
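# A worked example (file names and sizes assumed): with three files of 100 rows each and
# num_shards=4, size_per_shard = ceil(300 / 4) = 75. For shard_id=1 the target range is rows
# 75..149 of the concatenated dataset, so the function returns
#     [("a.mindrecord", 75, 100, False), ("b.mindrecord", 0, 50, False)]
# i.e. the last 25 rows of the first file plus the first 50 rows of the second. The final flag
# is True only when a shard covers a whole file; partially covered files are marked as shared
# between shards downstream and are never evicted from the local cache.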
def make_dataset_tuple(dataset_files, local_path):
    """ Calculate the total size of the dataset and the size of each dataset file. """

    dataset_file_size_list = []
    dataset_size = 0

    for dataset_file in dataset_files:
        meta_file = os.path.basename(dataset_file) + '.db'
        path = os.path.join(local_path, meta_file)
        try:
            conn = sqlite3.connect(path)
            c = conn.cursor()
            cursor = c.execute("SELECT COUNT(*) FROM INDEXES")
            for row in cursor:
                dataset_size += row[0]
                dataset_file_size_list.append((dataset_file, row[0]))
            conn.close()
        except Exception as e:
            raise RuntimeError(
                "Failed to get dataset size from metadata, err: " + str(e))
    return dataset_size, dataset_file_size_list


def fetch_meta_files(dataset_files, local_path, shard_id):
    """ Download all meta files from OBS, skipping files that already exist. """

    try:
        for df in dataset_files:
            dataset_file = os.path.basename(df)
            meta_file = dataset_file + '.db'
            remote_path = os.path.dirname(df)
            download_file(remote_path, meta_file, local_path, lock_file=meta_file)
    except Exception as e:
        raise RuntimeError(
            "Failed to fetch meta file from OBS, error: " + str(e))


def make_shard_files(dataset_files, num_shards, shard_id):
    """ Make the per-shard file list when shard_equal_rows is False. """

    idx = 0
    shard_files = []

    for dataset_file in dataset_files:
        if idx % num_shards == shard_id:
            shard_files.append((dataset_file, -1, -1, True))
        idx += 1
    return shard_files
def get_bucket_and_key(obs_path):
    r"""
    Split an OBS path into the bucket name and the object key.

    Args:
        obs_path: OBS path that starts with s3://.

    Returns:
        (str, str), bucketName and objectKey.
    """

    start = obs_path.find('//')
    end = obs_path.find('/', start + 2)
    if end == -1:
        return obs_path[start + 2:], ""
    return obs_path[start + 2:end], obs_path[end + 1:]
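# For example (paths assumed for illustration):
#     get_bucket_and_key("s3://my-bucket/jobs/sync")  ->  ("my-bucket", "jobs/sync")
#     get_bucket_and_key("s3://my-bucket")            ->  ("my-bucket", "")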
def exclusive_lock(func):
    """ Decorator that executes func under an exclusive file lock. """

    func_name = func.__name__

    @wraps(func)
    def wrapped_func(*args, **kwargs):
        try:
            lock_file = os.path.join('/tmp/', '{}.lock'.format(kwargs['lock_file']))
        except KeyError:
            raise RuntimeError("Lock file can not be found in function {}.".format(func_name))
        with open(lock_file, 'w') as fd:
            retry_cnt = 0
            success = False
            while True:
                if success:
                    break
                try:
                    if retry_cnt > config.MAX_RETRY:
                        raise RuntimeError("Function {} retries times {} has exceeded threshold {}.".format(
                            func_name, retry_cnt, config.MAX_RETRY))
                    fcntl.flock(fd, fcntl.LOCK_EX)
                    success = True
                    result = func(*args, **kwargs)
                except RuntimeError as e:
                    raise e
                except Exception as e:  # pylint: disable=W0703
                    retry_cnt += 1
                    import traceback
                    logger.error(traceback.format_exc())
                    time.sleep(config.RETRY_DELTA_TIME)
                finally:
                    fcntl.flock(fd, fcntl.LOCK_UN)
        return result
    return wrapped_func


def retry_execute(func):
    """ Decorator that retries func on unexpected errors. """

    func_name = func.__name__

    @wraps(func)
    def wrapper(*args, **kwargs):
        retry_cnt = 0
        success = False
        while True:
            if success:
                break
            try:
                if retry_cnt > config.MAX_RETRY:
                    err_msg = "Function {} retries times {} has exceeded threshold {}.".format(
                        func_name, retry_cnt, config.MAX_RETRY)
                    logger.error(err_msg)
                    raise RuntimeError(err_msg)
                result = func(*args, **kwargs)
                success = True
            except RuntimeError as e:
                raise e
            except Exception:  # pylint: disable=W0703
                retry_cnt += 1
                import traceback
                logger.error(traceback.format_exc())
                time.sleep(config.RETRY_DELTA_TIME)
        return result
    return wrapper
@retry_execute
def check_file_exists_in_obs(obs_path):
    """ Check that the file exists in OBS. """

    bucket_name, object_key = get_bucket_and_key(obs_path)
    resp = obsClient.getObjectMetadata(bucket_name, object_key)
    if resp.status < 300:
        logger.debug("[{} FUNCTION] OBS requestId: {}.".format(
            sys._getframe(), resp.requestId))  # pylint: disable=W0212
    else:
        err_msg = "File {} not found in OBS, please check again.".format(obs_path)
        logger.error(err_msg)
        raise RuntimeError(err_msg)


@retry_execute
def file_download_from_obs(obs_path, local_path):
    """ Download a file from OBS. """

    bucket_name, object_key = get_bucket_and_key(obs_path)
    downloadFile = local_path
    taskNum = config.TASK_NUM
    partSize = config.PART_SIZE
    enableCheckpoint = True

    resp = obsClient.downloadFile(
        bucket_name, object_key, downloadFile, partSize, taskNum, enableCheckpoint)
    if resp.status < 300:
        logger.debug("[{} FUNCTION] OBS requestId: {}.".format(
            sys._getframe(), resp.requestId))  # pylint: disable=W0212
    else:
        raise Exception("OBS SDK errorCode:{}, errMsg: {}.".format(
            resp.errorCode, resp.errorMessage))


@exclusive_lock
def download_file(remote_path, object_name, des_path, lock_file='tmp'):
    """ Download a file from OBS exclusively. """

    local_path = os.path.join(des_path, object_name)
    if os.path.exists(local_path):
        return
    if not os.path.exists(des_path):
        os.makedirs(des_path)
    obs_path = os.path.join(remote_path, object_name)

    check_file_exists_in_obs(obs_path)
    file_download_from_obs(obs_path, local_path)


@exclusive_lock
def init_cache_and_queue(cache, q, path, shard_file, idx, is_full_dataset, lock_file='tmp'):
    """ Initialize the cache and the queue according to the status of local dataset files. """

    dataset_file = os.path.basename(shard_file)
    if os.path.exists(path):  # found in local
        logger.info("[{} FUNCTION] Push dataset file {} to cache.".format(
            sys._getframe(), dataset_file))  # pylint: disable=W0212
        cache[dataset_file] = (idx, not is_full_dataset)
    else:
        logger.info("[{} FUNCTION] Push dataset file {} to downloading queue.".format(
            sys._getframe(), dataset_file))  # pylint: disable=W0212
        cache[dataset_file] = (-1, not is_full_dataset)
        q.put((idx, shard_file))


@exclusive_lock
def detect_file_exist(local_path, meta_file, lock_file='tmp'):
    """ Check whether the local dataset file exists. """
    if os.path.exists(os.path.join(local_path, meta_file)):
        return True
    return False


@retry_execute
def file_upload_to_obs(obs_path, sync_dir, ready_file_name):
    """ Upload the sync file to OBS. """

    bucket_name, object_key = get_bucket_and_key(obs_path)

    if not object_key:
        resp = obsClient.headBucket(bucket_name)
    else:
        if not object_key.endswith("/"):
            object_key += "/"
        resp = obsClient.getObjectMetadata(bucket_name, object_key)

    if resp.status < 300:
        logger.debug("[{} FUNCTION] OBS requestId: {}.".format(
            sys._getframe(), resp.requestId))  # pylint: disable=W0212
    else:
        raise RuntimeError("Directory/Bucket used for synchronization {} is not found in OBS, "
                           "please create it on OBS first.".format(obs_path))

    remote_dir = os.path.join(object_key, sync_dir)
    resp = obsClient.putContent(bucket_name, remote_dir, content=None)
    if resp.status < 300:
        logger.debug("[{} FUNCTION] OBS requestId: {}.".format(
            sys._getframe(), resp.requestId))  # pylint: disable=W0212
    else:
        raise Exception("OBS SDK errorCode:{}, errMsg: {}.".format(
            resp.errorCode, resp.errorMessage))
    resp = obsClient.putContent(bucket_name, os.path.join(
        remote_dir, ready_file_name), content='OK')
    if resp.status < 300:
        logger.debug("[{} FUNCTION] OBS requestId: {}.".format(
            sys._getframe(), resp.requestId))  # pylint: disable=W0212
    else:
        raise Exception("OBS SDK errorCode:{}, errMsg: {}.".format(
            resp.errorCode, resp.errorMessage))
@@ -2814,3 +2814,36 @@ def check_multi30k_dataset(method):
        return method(self, *args, **kwargs)

    return new_method


def check_obsminddataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(OBSMindDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_shards', 'shard_id']
        nreq_param_list = ['columns_list']
        nreq_param_bool = ['shard_equal_rows']
        nreq_param_str = ['server', 'ak', 'sk', 'sync_obs_path']

        dataset_files = param_dict.get('dataset_files')
        type_check(dataset_files, (list,), "dataset files")
        for dataset_file in dataset_files:
            if not isinstance(dataset_file, str):
                raise TypeError("Item of dataset files is not of type [{}], but got {}.".format(type(''),
                                                                                                type(dataset_file)))
        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_list, param_dict, list)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        validate_dataset_param_value(nreq_param_str, param_dict, str)

        server = param_dict.get('server')
        if not server.startswith(('http://', 'https://')):
            raise ValueError("server should be a str that starts with http:// or https://, but got {}.".format(server))

        check_sampler_shuffle_shard_options(param_dict)

        return method(self, *args, **kwargs)

    return new_method
@@ -0,0 +1,80 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test OBSMindDataset operator
"""
import pytest

from mindspore.dataset.engine.datasets_standard_format import OBSMindDataset
from mindspore import log as logger

DATA_DIR = ["s3://dataset/imagenet0", "s3://dataset/imagenet1"]


def test_obs_mindrecord_exception():
    """
    Feature: Test OBSMindDataset.
    Description: Invalid input.
    Expectation: Raise exception.
    """

    logger.info("Test error cases for OBSMindDataset")
    error_msg_0 = "Argument dataset files"
    with pytest.raises(TypeError, match=error_msg_0):
        OBSMindDataset("err_dataset", "https://dummy_site", "dummy_ak", "dummy_sk", "s3://dummy_sync_dir")

    error_msg_0_1 = "Item of dataset files"
    with pytest.raises(TypeError, match=error_msg_0_1):
        OBSMindDataset([1, 2], "https://dummy_site", "dummy_ak", "dummy_sk", "s3://dummy_sync_dir")

    error_msg_1 = "Argument server"
    with pytest.raises(TypeError, match=error_msg_1):
        OBSMindDataset(DATA_DIR, 12, "dummy_ak", "dummy_sk", "s3://dummy_sync_dir")

    error_msg_1_1 = "server should"
    with pytest.raises(ValueError, match=error_msg_1_1):
        OBSMindDataset(DATA_DIR, "ftp://dummy_site", "dummy_ak", "dummy_sk", "s3://dummy_sync_dir")

    error_msg_2 = "Argument ak"
    with pytest.raises(TypeError, match=error_msg_2):
        OBSMindDataset(DATA_DIR, "https://dummy_site", 12, "dummy_sk", "s3://dummy_sync_dir")

    error_msg_3 = "Argument sk"
    with pytest.raises(TypeError, match=error_msg_3):
        OBSMindDataset(DATA_DIR, "https://dummy_site", "dummy_ak", 12, "s3://dummy_sync_dir")

    error_msg_4 = "Argument sync_obs_path"
    with pytest.raises(TypeError, match=error_msg_4):
        OBSMindDataset(DATA_DIR, "https://dummy_site", "dummy_ak", "dummy_sk", 12)

    error_msg_5 = "Input shard_id is not within the required interval"
    with pytest.raises(ValueError, match=error_msg_5):
        OBSMindDataset(DATA_DIR, "https://dummy_site", "dummy_ak",
                       "dummy_sk", "s3://dummy_sync_dir", num_shards=2, shard_id=-1)
    with pytest.raises(ValueError, match=error_msg_5):
        OBSMindDataset(DATA_DIR, "https://dummy_site", "dummy_ak",
                       "dummy_sk", "s3://dummy_sync_dir", num_shards=4, shard_id=4)
    with pytest.raises(ValueError, match=error_msg_5):
        OBSMindDataset(DATA_DIR, "https://dummy_site", "dummy_ak",
                       "dummy_sk", "s3://dummy_sync_dir", num_shards=2, shard_id=4)

    error_msg_7 = "Argument shard_equal_rows"
    with pytest.raises(TypeError, match=error_msg_7):
        OBSMindDataset(DATA_DIR, "https://dummy_site", "dummy_ak",
                       "dummy_sk", "s3://dummy_sync_dir", shard_equal_rows=1)


if __name__ == '__main__':
    test_obs_mindrecord_exception()