!30810 [MD] add OBSMindDataset

Merge pull request !30810 from liyong126/add_obs_mindrecord_dataset
This commit is contained in:
i-robot 2022-03-10 07:28:46 +00:00 committed by Gitee
commit 9fcab9184e
7 changed files with 1156 additions and 1 deletions

View File

@ -30,7 +30,10 @@ import mindspore._c_dataengine as cde
from mindspore import log as logger
from .datasets import UnionBaseDataset, SourceDataset, MappableDataset, Shuffle, Schema, \
shuffle_to_shuffle_mode, shuffle_to_bool
from .validators import check_minddataset, check_tfrecorddataset, check_csvdataset
from .datasets_user_defined import GeneratorDataset
from .obs.obs_mindrecord_dataset import MindRecordFromOBS
from .validators import check_csvdataset, check_minddataset, check_tfrecorddataset, check_obsminddataset
from ..core.validator_helpers import replace_none
from . import samplers
@ -324,3 +327,88 @@ class TFRecordDataset(SourceDataset, UnionBaseDataset):
schema = self.schema.cpp_schema if isinstance(self.schema, Schema) else self.schema
return cde.TFRecordNode(self.dataset_files, schema, self.columns_list, self.num_samples, self.shuffle_flag,
self.num_shards, self.shard_id, self.shard_equal_rows)
class OBSMindDataset(GeneratorDataset):
"""
A source dataset that reads and parses MindRecord files stored in OBS.
The columns of generated dataset depend on the source MindRecord files.
Args:
dataset_files (list[str]): List of files on OBS to be read; file paths are in
the format of s3://.
server (str): Endpoint for accessing OBS. For example: https://your-endpoint:9000.
ak (str): Access key ID of OBS.
sk (str): Secret access key of OBS.
sync_obs_path (str): OBS directory path used for synchronization; users need to
create it on OBS in advance. The path is in the format of s3://.
column_list (list[str], optional): List of columns to be read (default=None, read all columns).
shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch
(default=Shuffle.GLOBAL).
If shuffle is False, no shuffling will be performed;
If shuffle is True, the behavior is the same as setting shuffle to Shuffle.GLOBAL;
Otherwise, there are two levels of shuffling:
- Shuffle.GLOBAL: Shuffle both the files and samples.
- Shuffle.FILES: Shuffle files only.
num_shards (int, optional): Number of shards that the dataset will be divided
into (default=None).
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument can only be specified when num_shards is also specified.
shard_equal_rows (bool, optional): Get equal rows for all shards (default=True). If shard_equal_rows
is False, the number of rows in each shard may not be equal, which may lead to a failure in
distributed training. When the number of samples per MindRecord file is not equal, it is
suggested to set this to True. This argument should only be specified when num_shards is also specified.
Raises:
RuntimeError: If `sync_obs_path` does not exist.
ValueError: If `column_list` is invalid.
RuntimeError: If `num_shards` is specified but `shard_id` is None.
RuntimeError: If `shard_id` is specified but `num_shards` is None.
ValueError: If `shard_id` is invalid (< 0 or >= `num_shards`).
Note:
- It's necessary to create a synchronization directory on OBS in
advance, which is defined by the parameter `sync_obs_path`.
- If training is offline (not on the cloud), it's recommended to set the
environment variable `BATCH_JOB_ID`.
- In distributed training, if there are multiple nodes (servers), all 8
devices must be used in each node (server). If there is only one
node (server), there is no such restriction.
Examples:
>>> dataset_obs_dir = ["s3://path/to/obs_dataset_file_1", "s3://path/to/obs_dataset_file_2"]
>>> sync_obs_dir = "s3://sync-dir"
>>> dataset = ds.OBSMindDataset(dataset_obs_dir, "https://your-endpoint:9000", "AK of OBS", "SK of OBS",
... sync_obs_dir, shuffle=True, num_shards=num_shards, shard_id=shard_id)
"""
@check_obsminddataset
def __init__(self, dataset_files, server, ak, sk, sync_obs_path,
column_list=None,
shuffle=Shuffle.GLOBAL,
num_shards=None,
shard_id=None,
shard_equal_rows=True):
from .obs.config_loader import config
config.AK = ak
config.SK = sk
config.SERVER = server
config.SYNC_OBS_PATH = sync_obs_path
dataset = MindRecordFromOBS(dataset_files, column_list, shuffle, num_shards, shard_id,
shard_equal_rows, config.DATASET_LOCAL_PATH)
if not column_list:
column_list = dataset.get_col_names()
else:
full_column_list = dataset.get_col_names()
if not set(column_list).issubset(full_column_list):
raise ValueError("columns_list: {} can not found in MindRecord fields: {}".format(column_list,
full_column_list))
super().__init__(source=dataset, column_names=column_list, num_shards=None, shard_id=None, shuffle=False)
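# A minimal usage sketch for the class above, assuming OBSMindDataset is exported as
# mindspore.dataset.OBSMindDataset (as in the docstring example); the endpoint, keys,
# bucket paths and batch size below are placeholders, not values from this change.
def _example_obs_minddataset_usage():
    """ Build one shard of an OBS-hosted MindRecord dataset and iterate over it. """
    import mindspore.dataset as ds
    from mindspore.communication import get_rank, get_group_size
    dataset_obs_files = ["s3://bucket/path/imagenet.mindrecord0", "s3://bucket/path/imagenet.mindrecord1"]
    dataset = ds.OBSMindDataset(dataset_obs_files, "https://your-endpoint:9000", "your-ak", "your-sk",
                                "s3://bucket/sync_dir", shuffle=True,
                                num_shards=get_group_size(), shard_id=get_rank())
    dataset = dataset.batch(32, drop_remainder=True)
    for _ in dataset.create_dict_iterator(output_numpy=True, num_epochs=1):
        pass  # each item is a dict mapping column names to NumPy arrays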

View File

@ -0,0 +1,23 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Introduction to dataset/engine/obs:
dataset/engine/obs provides the implementation of OBSMindDataset.
"""
from .obs_mindrecord_dataset import MindRecordFromOBS, sync_wait_for_dataset
__all__ = ["MindRecordFromOBS", "sync_wait_for_dataset"]

View File

@ -0,0 +1,59 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This dataset module provides internal global variables for OBSMindDataset API.
"""
import os
class Config():
""" Config class for OBSMindDataset """
WORKING_PATH = "/cache"
DATASET_LOCAL_PATH = os.path.join(WORKING_PATH, "dataset")
DISK_THRESHOLD = 0.75
TASK_NUM = 8
PART_SIZE = 10*1024*1024
MAX_RETRY = 3
RETRY_DELTA_TIME = 10
WARMINGUP_TIME = 10 # warmup time
WAIT_STEP_TIME = 0.1 # wait time when cache miss
WAIT_META_TIME = 1
SEED = 1234
class _Config:
""" Internal class that get and set global variables. """
def __init__(self):
self.config = dict((k, v) for k, v in Config.__dict__.items(
) if not callable(v) and not k.startswith('__'))
def __getattr__(self, key):
if key in os.environ:
return os.environ[key]
if key in self.config:
return self.config[key]
return None
def __setattr__(self, key, value):
if key == 'config':
self.__dict__[key] = value
else:
self.config[key] = value
config = _Config()
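# An illustrative sketch of how the `config` object above resolves values: an environment
# variable with the same name wins over the Config default, runtime assignments go into the
# internal dict, and unknown keys fall back to None. Values read from os.environ are strings.
def _example_config_precedence():
    """ Demonstrate the lookup order implemented by _Config.__getattr__ / __setattr__. """
    assert config.TASK_NUM == 8                 # default taken from the Config class
    os.environ['TASK_NUM'] = '16'
    assert config.TASK_NUM == '16'              # environment variable overrides, returned as a string
    config.RETRY_DELTA_TIME = 5                 # runtime override stored in the internal dict
    assert config.RETRY_DELTA_TIME == 5
    assert config.SOME_UNKNOWN_KEY is None      # unknown keys resolve to None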

View File

@ -0,0 +1,495 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This dataset module provides the internal dataset API which loads MindRecord files from OBS.
"""
import math
from multiprocessing.dummy import Pool as ThreadPool
from multiprocessing.managers import SyncManager
import os
import queue
import random
import sys
import time
from mindspore import log as logger
from ..datasets import Shuffle
from ...core.config import set_seed
class Manager(SyncManager):
pass
def get_manager():
""" PriorityQueue that cross threads."""
Manager.register("PriorityQueue", queue.PriorityQueue)
m = Manager()
m.start()
return m
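# A tiny sketch of the manager-backed queue returned above: it behaves like queue.PriorityQueue,
# so (index, file) tuples come out in ascending index order, which is what gives earlier shard
# files download priority. The file names are made-up examples.
def _example_priority_queue():
    """ Show the ordering of the shared PriorityQueue proxy. """
    m = get_manager()
    q = m.PriorityQueue()
    q.put((2, 'b.mindrecord'))
    q.put((0, 'a.mindrecord'))
    assert q.get() == (0, 'a.mindrecord')
    m.shutdown()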
def init_cache_and_working_queue(cache, q, shard_files, local_path):
"""
Initialize the downloading queue and the local cache which stores the status of local dataset files.
"""
from .util import init_cache_and_queue
idx = 0
for shard_file, _, _, is_full_dataset in shard_files:
dataset_file = os.path.basename(shard_file)
path = os.path.join(local_path, dataset_file)
init_cache_and_queue(cache, q, path, shard_file,
idx, is_full_dataset, lock_file=dataset_file)
idx += 1
return cache, q
def remove_unused_dataset(local_path, num_shards, shard_id, epoch_num):
""" Rank(rank_id mod 8 equal to 0) remove all dataset files. """
from .config_loader import config
if not num_shards:
return
# If num_shards is less than or equal to 8, assume that there is only one node (server) and
# the dataset does not need to be removed.
if num_shards <= 8 or shard_id % 8 != 0:
return
sync_dir = '/cache/sync_data/' + str(epoch_num)
while True:
if os.path.exists(sync_dir) and len(os.listdir(sync_dir)) >= min(num_shards - 1, 7):
break
time.sleep(config.WARMINGUP_TIME)
logger.info("[{} FUNCTION] Shard: {} wait for other rank ready in epoch: {}.".format(
sys._getframe().f_code.co_name, shard_id, epoch_num)) # pylint: disable=W0212
files = os.listdir(local_path)
for dataset_file in files:
if dataset_file.endswith('.db'):
continue
dataset_path = os.path.join(local_path, dataset_file)
os.remove(dataset_path)
for ready_file in os.listdir(sync_dir):
os.remove(os.path.join(sync_dir, ready_file))
def wait_remove_datset(num_shards, shard_id, epoch_num):
""" Rank(rank_id mod 8 not equal to 0) wait for removing dataset files. """
from .config_loader import config
if not num_shards:
return
if num_shards <= 8 or shard_id % 8 == 0:
return
sync_dir = '/cache/sync_data/' + str(epoch_num)
if not os.path.exists(sync_dir):
try:
os.makedirs(sync_dir)
except FileExistsError:
pass
sync_file = os.path.join(sync_dir, 'ready_' + str(shard_id))
with open(sync_file, 'w') as f:
f.write('ok')
while True:
if os.path.exists(sync_dir) and not os.listdir(sync_dir):
break
time.sleep(config.WARMINGUP_TIME)
logger.info("[{} FUNCTION] Shard: {} wait for removing dataset files in epoch: {}.".format(
sys._getframe().f_code.co_name, shard_id, epoch_num)) # pylint: disable=W0212
def init_shard_files(dataset_files, shuffle, seed, num_shards, shard_id, shard_equal_rows,
size_per_shard, local_path, current_epoch):
""" Calculate the dataset files required by each sharding and the corresponding index. """
from .config_loader import config
from .util import detect_all_meta_files, fetch_meta_files, make_dataset_tuple, make_shard_files, make_shard_samples
shard_files = None
if shuffle is False:
pass
else:
set_seed(seed)
random.shuffle(dataset_files)
if num_shards: # distributed training
# As each sharding has the same number of samples, need to fetch all meta files.
if shard_equal_rows:
if size_per_shard is None:
if shard_id % 8 == 0:
fetch_meta_files(dataset_files, local_path, shard_id)
else:
while detect_all_meta_files(dataset_files, local_path) is False:
time.sleep(config.WAIT_META_TIME)
full_dataset_size, dataset_file_size_list = make_dataset_tuple(
dataset_files, local_path)
size_per_shard = math.ceil(full_dataset_size / num_shards)
shard_files = make_shard_samples(
dataset_file_size_list, size_per_shard, shard_id)
else:
shard_files = make_shard_files(dataset_files, num_shards, shard_id)
else:
shard_files = [(dataset_file, -1, -1, True)
for dataset_file in dataset_files]
logger.info("[{} FUNCTION] Shard: {} expect dataset: {} in epoch: {}.".format(
sys._getframe().f_code.co_name, shard_id, shard_files, current_epoch)) # pylint: disable=W0212
return shard_files, size_per_shard
def download_work(shard_id, current_idx, local_path, cache, q):
""" daemon process in backend. """
from .config_loader import config
from .util import try_load_from_obs, get_used_disk_per
while True:
idx, dataset_file = q.get()
used_disk = get_used_disk_per()
while used_disk > config.DISK_THRESHOLD:
logger.info("[{} FUNCTION] Used disk space is {}%, and the disk threshold is {}%.".format(
sys._getframe().f_code.co_name, used_disk*100, config.DISK_THRESHOLD*100)) # pylint: disable=W0212
retry_cnt = 0
has_deleted = delete_candidate_datasets(
current_idx.value, idx, cache, q, local_path)
while not has_deleted:
if retry_cnt > config.MAX_RETRY:
logger.warning("Delete operation retries times {} has exceeded threshold {}, "
"please clear enough disk space.".format(retry_cnt, config.MAX_RETRY))
has_deleted = delete_candidate_datasets(
current_idx.value, idx, cache, q, local_path)
retry_cnt += 1
time.sleep(config.RETRY_DELTA_TIME)
used_disk = get_used_disk_per()
logger.info("[{} FUNCTION] Shard: {} try to download: {}.".format(
sys._getframe().f_code.co_name, shard_id, dataset_file)) # pylint: disable=W0212
# update cache
remote_path = os.path.dirname(dataset_file)
dataset_file = os.path.basename(dataset_file)
_, is_shared = cache[dataset_file]
try_load_from_obs(remote_path, dataset_file, local_path, shard_id)
cache[dataset_file] = (idx, is_shared)
logger.info("[{} FUNCTION] Shard: {} finish to download: {}.".format(
sys._getframe().f_code.co_name, shard_id, dataset_file)) # pylint: disable=W0212
def delete_candidate_datasets(current_idx, queue_top_idx, cache, q, local_path):
"""
1. Try to delete all the dataset files that have already been consumed in the current epoch.
2. Otherwise, try to delete one low-priority dataset file of the epoch.
3. As soon as a low-priority dataset file is deleted, it is put back into the download queue.
"""
used_datasets = []
low_priority_dataset = ''
max_idx = -1
delete = False
for k, v in cache.items():
idx, is_shared = v
if is_shared is False and idx >= 0:
if idx > max_idx:
max_idx = idx
low_priority_dataset = k
if idx < current_idx:
used_datasets.append(k)
for used_dataset in used_datasets:
dataset_path = os.path.join(local_path, used_dataset)
if not os.path.exists(dataset_path):
continue
# update cache
idx, is_shared = cache[used_dataset]
cache[used_dataset] = (-1, is_shared)
os.remove(dataset_path)
delete = True
logger.info("[{} FUNCTION] Delete used dataset file: {} and update the cache.".format(
sys._getframe().f_code.co_name, used_dataset)) # pylint: disable=W0212
if delete:
return True
if max_idx <= current_idx or max_idx <= queue_top_idx:
return False
dataset_path = os.path.join(local_path, low_priority_dataset)
if not os.path.exists(dataset_path):
return False
# update cache
idx, is_shared = cache[low_priority_dataset]
cache[low_priority_dataset] = (-1, is_shared)
os.remove(dataset_path)
q.put((idx, low_priority_dataset))
logger.info("[{} FUNCTION] Delete low priority dataset file: {} and update the cache.".format(
sys._getframe().f_code.co_name, low_priority_dataset)) # pylint: disable=W0212
return True
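# An illustrative sketch (made-up file names and indices) of the cache layout that the eviction
# logic above operates on. Each entry maps a local dataset file name to (download_index, is_shared):
# a download_index of -1 means the file is not on disk yet, and is_shared=True marks files that are
# only partially consumed by this shard, which delete_candidate_datasets() never evicts.
_EXAMPLE_CACHE_LAYOUT = {
    'imagenet.mindrecord0': (0, False),   # consumed earlier in the epoch: first eviction candidate
    'imagenet.mindrecord1': (2, False),   # not needed yet: evicted only as a last resort, then re-queued
    'imagenet.mindrecord2': (-1, False),  # not downloaded yet: nothing to evict
    'imagenet.mindrecord3': (1, True),    # shared with another shard: always kept on disk
}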
def _sync_up_for_obs_mindrecord_dataset(rank_id, current_epoch):
""" Upload the synchronization file to OBS. """
from .config_loader import config
from .util import file_upload_to_obs
sync_info = "download_dataset"
job_id = os.environ.get('BATCH_JOB_ID', 'unknown')
ready_file_name = sync_info + '_ready_' + str(rank_id) + '.txt'
ready_dir = os.path.join(job_id, str(current_epoch) + "/")
file_upload_to_obs(config.SYNC_OBS_PATH, ready_dir, ready_file_name)
logger.info("[{} FUNCTION] Current rank:{}'s sync file:{} is ready for epoch:{}.".format(
sys._getframe().f_code.co_name, rank_id, os.path.join(ready_dir, ready_file_name), current_epoch)) # pylint: disable=W0212
def sync_wait_for_dataset(rank_id, rank_size, current_epoch):
"""
Wait until the dataset files required by all devices are downloaded.
Note:
It should be used together with `mindspore.dataset.OBSMindDataset` and
be called before each epoch.
Args:
rank_id (int): Rank ID of the device.
rank_size (int): Total number of ranks (devices).
current_epoch (int): Current epoch number.
Examples:
>>> # Create a synchronization callback
>>>
>>> from mindspore.dataset import sync_wait_for_dataset
>>> from mindspore.train.callback import Callback
>>>
>>> class SyncForDataset(Callback):
... def __init__(self):
... super(SyncForDataset, self).__init__()
... def epoch_begin(self, run_context):
... cb_params = run_context.original_args()
... epoch_num = cb_params.cur_epoch_num
... sync_wait_for_dataset(rank_id, rank_size, epoch_num)
"""
from .config_loader import config
from .util import obsClient, get_bucket_and_key
bucket_name, object_key = get_bucket_and_key(config.SYNC_OBS_PATH)
job_id = os.environ.get('BATCH_JOB_ID', 'unknown')
ready_dir = os.path.join(object_key, job_id, str(current_epoch) + "/")
success = False
while True:
if success:
break
try:
# no guarantee that the dir is included.
resp = obsClient.listObjects(bucket_name, prefix=ready_dir)
if resp.status < 300:
ready_num = 0
for content in resp.body.contents:
if content.key.endswith(".txt"):
ready_num += 1
if ready_num >= rank_size:
success = True
else:
logger.warning("[{} FUNCTION] OBS SDK errorCode:{}, errMsg: {}.".format(
sys._getframe(), resp.errorCode, resp.errorMessage)) # pylint: disable=W0212
except Exception: # pylint: disable=W0703
import traceback
logger.error(traceback.format_exc())
time.sleep(config.RETRY_DELTA_TIME)
logger.info("[{} FUNCTION] Waiting for sync dir:{} and current_rank:{}, total_rank:{}, "
"ready_rank:{} in epoch:{}.".format(sys._getframe().f_code.co_name, ready_dir, # pylint: disable=W0212
rank_id, rank_size, ready_num, current_epoch))
logger.info("[{} FUNCTION] Succeed to sync dir:{} and begin epoch:{}.".format(
sys._getframe().f_code.co_name, ready_dir, current_epoch)) # pylint: disable=W0212
def _sync_for_obs_mindrecord_dataset(shard_files, cache, num_shards, shard_id, current_epoch):
""" Synchronize all shardings. """
from .config_loader import config
while True:
dataset, _, _, _ = shard_files[-1]
current_dataset = os.path.basename(dataset)
hit_cache = cache[current_dataset][0]
if hit_cache >= 0: # hit cache
logger.info("[{} FUNCTION] Current_rank:{} has download:{} for epoch:{}.".format(
sys._getframe().f_code.co_name, shard_id, dataset, current_epoch)) # pylint: disable=W0212
_sync_up_for_obs_mindrecord_dataset(shard_id, current_epoch)
break
time.sleep(config.WARMINGUP_TIME)
logger.info("[{} FUNCTION] Current_rank:{} wait for downloading:{} in epoch:{}.".format(
sys._getframe().f_code.co_name, shard_id, dataset, current_epoch)) # pylint: disable=W0212
sync_wait_for_dataset(shard_id, num_shards, current_epoch)
class MindRecordFromOBS:
""" Internal class which load remote dataset files from OBS. """
def __init__(self, dataset_files, columns_list, shuffle, num_shards, shard_id, shard_equal_rows, local_path):
self._dataset_files = dataset_files
self._columns_list = columns_list
self._num_shards = num_shards
self._shard_id = shard_id
self._shard_equal_rows = shard_equal_rows
self._local_path = os.path.realpath(local_path)
self._shuffle = Shuffle.GLOBAL if shuffle is None or shuffle is True else shuffle
from .config_loader import config
self._seed = config.SEED
self._size_per_shard = None
self._curr_epoch = 1
self._curr_step = 1
self._shard_files, self._size_per_shard = init_shard_files(self._dataset_files, self._shuffle, self._seed,
self._num_shards, self._shard_id,
self._shard_equal_rows, self._size_per_shard,
self._local_path, self._curr_epoch)
m = get_manager()
self._queue = m.PriorityQueue()
self._cache = m.dict()
self._index = 0
self._current_idx = m.Value('i', self._index)
self._cache, self._queue = init_cache_and_working_queue(
self._cache, self._queue, self._shard_files, self._local_path)
self._index = 0
self._first_epoch = True
self._iteration = None
self._cache_miss_times = 0
self._pool = ThreadPool(processes=1)
self._pool.apply_async(download_work, (self._shard_id,
self._current_idx, self._local_path, self._cache, self._queue))
_sync_for_obs_mindrecord_dataset(
self._shard_files, self._cache, self._num_shards, self._shard_id, self._curr_epoch)
def get_col_names(self):
""" Get column names of Mindrecord format dataset."""
from ..datasets_standard_format import MindDataset
target_dataset = None
while target_dataset is None:
for f, _, _, _ in self._shard_files:
current_dataset = os.path.basename(f)
if self._cache[current_dataset][0] >= 0:
target_dataset = current_dataset
path = os.path.join(self._local_path, target_dataset)
_iteration = MindDataset(dataset_files=[path], shuffle=False)
return _iteration.get_col_names()
def __next__(self):
from .config_loader import config
from ..datasets_standard_format import MindDataset
from .util import make_sampler
if self._iteration:
try:
self._curr_step += 1
return next(self._iteration)
except StopIteration:
self._index += 1
self._current_idx.value = self._index
self._iteration = None
if self._index >= len(self._shard_files):
self._first_epoch = False
self._curr_epoch += 1
self._curr_step = 0
raise StopIteration
return next(self)
else:
f, start, end, is_full_dataset = self._shard_files[self._index]
current_dataset = os.path.basename(f)
hit_cache = self._cache[current_dataset][0]
if hit_cache >= 0: # hit cache
self._cache_miss_times = 0
# launch pipeline
sampler = make_sampler(
self._shuffle, is_full_dataset, start, end)
path = os.path.join(self._local_path, current_dataset)
logger.info("[{} FUNCTION] Shard:{} start to load dataset:{} in epoch:{}.".format(
sys._getframe().f_code.co_name, self._shard_id, path, self._curr_epoch)) # pylint: disable=W0212
self._iteration = MindDataset(dataset_files=[path], columns_list=self._columns_list, sampler=sampler,
shuffle=None).create_tuple_iterator(num_epochs=1, output_numpy=True)
else:
# cache miss
self._cache_miss_times += 1
logger.info("[{} FUNCTION] Cache miss in shard {} for times {}, expect dataset {}.".format(
sys._getframe().f_code.co_name, self._shard_id, self._cache_miss_times, current_dataset)) # pylint: disable=W0212
time.sleep(self._cache_miss_times * config.WAIT_STEP_TIME)
return next(self)
def __iter__(self):
if self._first_epoch:
self._index = 0
self._current_idx.value = self._index
self._iteration = None
return self
self._index = 0
self._current_idx.value = self._index
self._seed += 1
self._iteration = None
self._shard_files, self._size_per_shard = init_shard_files(self._dataset_files, self._shuffle, self._seed,
self._num_shards, self._shard_id,
self._shard_equal_rows, self._size_per_shard,
self._local_path, self._curr_epoch)
self._cache.clear()
# reset queue
try:
while True:
self._queue.get_nowait()
except queue.Empty:
pass
remove_unused_dataset(
self._local_path, self._num_shards, self._shard_id, self._curr_epoch)
wait_remove_datset(self._num_shards, self._shard_id, self._curr_epoch)
self._cache, self._queue = init_cache_and_working_queue(
self._cache, self._queue, self._shard_files, self._local_path)
_sync_for_obs_mindrecord_dataset(
self._shard_files, self._cache, self._num_shards, self._shard_id, self._curr_epoch)
return self
def __len__(self):
from .util import fetch_meta_files, make_dataset_tuple
if self._size_per_shard is not None:
return self._size_per_shard
dataset_files = []
for dataset_file, _, _, _ in self._shard_files:
dataset_files.append(dataset_file)
fetch_meta_files(dataset_files, self._local_path, self._shard_id)
self._size_per_shard, _ = make_dataset_tuple(
dataset_files, self._local_path)
return len(self)

View File

@ -0,0 +1,377 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This dataset module provides internal utility function for OBSMindDataset API.
"""
import fcntl
import os
import shutil
import sys
import sqlite3
import time
from functools import wraps
from obs import ObsClient
from mindspore import log as logger
from .config_loader import config
from ..datasets import Shuffle
from ..samplers import RandomSampler, SequentialSampler, SubsetSampler, SubsetRandomSampler
obsClient = ObsClient(
access_key_id=config.AK,
secret_access_key=config.SK,
server=config.SERVER
)
def get_used_disk_per():
""" Get the disk usage of working directory."""
if not os.path.exists(config.WORKING_PATH):
try:
os.makedirs(config.WORKING_PATH)
except FileExistsError:
pass
total, used, _ = shutil.disk_usage(config.WORKING_PATH)
return used / total
def try_load_from_obs(remote_path, dataset_file, local_path, shard_id):
""" Download all dataset files from obs, skip if it exists. """
try:
if not os.path.exists(os.path.join(local_path, dataset_file)):
download_file(remote_path, dataset_file, local_path, lock_file=dataset_file)
meta_file = dataset_file + '.db'
if not os.path.exists(os.path.join(local_path, meta_file)):
download_file(remote_path, meta_file, local_path, lock_file=meta_file)
except Exception as e:
raise RuntimeError("Failed to fetch file from obs, error: " + str(e))
def detect_all_meta_files(dataset_files, local_path):
""" Checking that all meta files exit on local. """
all_meta_files = True
for f in dataset_files:
dataset_file = os.path.basename(f)
meta_file = dataset_file + '.db'
if detect_file_exist(local_path, meta_file, lock_file=meta_file) is False:
all_meta_files = False
break
return all_meta_files
def make_sampler(shuffle, is_full_dataset, start, end):
""" Generate a proper sampler based on inputs. """
sampler = None
if shuffle == Shuffle.GLOBAL:
if is_full_dataset:
sampler = RandomSampler()
else:
sampler = SubsetRandomSampler(list(range(start, end)))
else:
if is_full_dataset:
sampler = SequentialSampler()
else:
sampler = SubsetSampler(list(range(start, end)))
return sampler
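# An illustrative mapping (start/end values are made up) of how make_sampler() picks a sampler:
# global shuffle yields a random sampler, any other shuffle mode yields a sequential one, and the
# subset variants are used when this shard only owns rows [start, end) of the file.
def _example_make_sampler():
    """ Show the four sampler combinations produced by make_sampler(). """
    assert isinstance(make_sampler(Shuffle.GLOBAL, True, -1, -1), RandomSampler)
    assert isinstance(make_sampler(Shuffle.GLOBAL, False, 0, 100), SubsetRandomSampler)
    assert isinstance(make_sampler(Shuffle.FILES, True, -1, -1), SequentialSampler)
    assert isinstance(make_sampler(Shuffle.FILES, False, 0, 100), SubsetSampler)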
def make_shard_samples(dataset_file_size_list, size_per_shard, shard_id):
""" Make sharding files when shard_equal_rows is True. """
pre_cnt = 0
shard_files = []
finish = False
while finish is False:
for f, dataset_size in dataset_file_size_list:
start_idx = shard_id * size_per_shard
end_idx = (shard_id + 1) * size_per_shard
push = False
is_full_dataset = False
if pre_cnt <= start_idx < pre_cnt + dataset_size:
start = start_idx - pre_cnt
push = True
if pre_cnt < end_idx <= pre_cnt + dataset_size:
end = end_idx - pre_cnt
else:
end = dataset_size
if start_idx <= pre_cnt < end_idx:
start = 0
push = True
if pre_cnt + dataset_size >= end_idx:
end = end_idx - pre_cnt
else:
end = dataset_size
if push:
if start == 0 and end == dataset_size:
is_full_dataset = True
shard_files.append((f, start, end, is_full_dataset))
pre_cnt += dataset_size
if pre_cnt >= (shard_id + 1) * size_per_shard:
finish = True
return shard_files
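# A worked example (made-up sizes) of the sharding arithmetic above: two files with 100 samples
# each and size_per_shard=70. Shard 0 owns rows [0, 70) of 'a'; shard 1 owns rows [70, 100) of 'a'
# plus rows [0, 40) of 'b'. The trailing False means the shard does not consume the whole file.
def _example_make_shard_samples():
    """ Check the per-shard (file, start, end, is_full_dataset) tuples. """
    sizes = [('a', 100), ('b', 100)]
    assert make_shard_samples(sizes, size_per_shard=70, shard_id=0) == [('a', 0, 70, False)]
    assert make_shard_samples(sizes, size_per_shard=70, shard_id=1) == [('a', 70, 100, False),
                                                                        ('b', 0, 40, False)]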
def make_dataset_tuple(dataset_files, local_path):
""" Calculates the total size of the dataset and the size of each dataset file """
dataset_file_size_list = []
dataset_size = 0
for dataset_file in dataset_files:
meta_file = os.path.basename(dataset_file) + '.db'
path = os.path.join(local_path, meta_file)
try:
conn = sqlite3.connect(path)
c = conn.cursor()
cursor = c.execute("SELECT COUNT(*) FROM INDEXES")
for row in cursor:
dataset_size += row[0]
dataset_file_size_list.append((dataset_file, row[0]))
conn.close()
except Exception as e:
raise RuntimeError(
"Failed to get dataset size from metadata, err: " + str(e))
return dataset_size, dataset_file_size_list
def fetch_meta_files(dataset_files, local_path, shard_id):
""" Download all meta files from obs, skip if it exists"""
try:
for df in dataset_files:
dataset_file = os.path.basename(df)
meta_file = dataset_file + '.db'
remote_path = os.path.dirname(df)
download_file(remote_path, meta_file, local_path, lock_file=meta_file)
except Exception as e:
raise RuntimeError(
"Failed to fetch meta file from OBS, error: " + str(e))
def make_shard_files(dataset_files, num_shards, shard_id):
""" Make sharding files when shard_equal_rows is False. """
idx = 0
shard_files = []
for dataset_file in dataset_files:
if idx % num_shards == shard_id:
shard_files.append((dataset_file, -1, -1, True))
idx += 1
return shard_files
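# A small example (made-up file names) of the round-robin assignment above: with num_shards=2,
# shard 0 gets every even-indexed file and consumes each of them fully, which is why start/end
# are -1 and is_full_dataset is True.
def _example_make_shard_files():
    """ Check the round-robin distribution of whole files across shards. """
    files = ['a.mindrecord', 'b.mindrecord', 'c.mindrecord']
    assert make_shard_files(files, num_shards=2, shard_id=0) == [('a.mindrecord', -1, -1, True),
                                                                 ('c.mindrecord', -1, -1, True)]
    assert make_shard_files(files, num_shards=2, shard_id=1) == [('b.mindrecord', -1, -1, True)]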
def get_bucket_and_key(obs_path):
r"""
Split an OBS path into a bucket name and an object key.
Args:
obs_path: obs path that starts with s3://.
Returns:
(str, str), bucketName and objectKey.
"""
start = obs_path.find('//')
end = obs_path.find('/', start + 2)
if end == -1:
return obs_path[start + 2:], ""
return obs_path[start + 2:end], obs_path[end + 1:]
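# A couple of worked examples (made-up paths) for the helper above: the bucket is the text between
# "//" and the next "/", and everything after that slash is the object key.
def _example_get_bucket_and_key():
    """ Check the bucket/object-key split for OBS paths. """
    assert get_bucket_and_key("s3://bucket/dir/file.mindrecord") == ("bucket", "dir/file.mindrecord")
    assert get_bucket_and_key("s3://bucket") == ("bucket", "")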
def exclusive_lock(func):
""" Decorator that execute func under exclusive lock. """
@wraps(func)
def wrapped_func(*args, **kwargs):
try:
lock_file = os.path.join('/tmp/', '{}.lock'.format(kwargs['lock_file']))
except KeyError:
raise RuntimeError("Lock file can not found in function {}.".format(func_name))
with open(lock_file, 'w') as fd:
retry_cnt = 0
success = False
while True:
if success:
break
try:
if retry_cnt > config.MAX_RETRY:
raise RuntimeError("Function {} retries times {} has exceeded threshold {}.".format(
func_name, retry_cnt, config.MAX_RETRY))
fcntl.flock(fd, fcntl.LOCK_EX)
success = True
result = func(*args, **kwargs)
except RuntimeError as e:
raise e
except Exception as e: # pylint: disable=W0703
retry_cnt += 1
import traceback
logger.error(traceback.format_exc())
time.sleep(config.RETRY_DELTA_TIME)
finally:
fcntl.flock(fd, fcntl.LOCK_UN)
return result
return wrapped_func
def retry_execute(func):
""" Decorator that retry on unexpected errors. """
func_name = func.__name__
@wraps(func)
def wrapper(*args, **kwargs):
retry_cnt = 0
success = False
while True:
if success:
break
try:
if retry_cnt > config.MAX_RETRY:
err_msg = "Function {} retries times {} has exceeded threshold {}.".format(
func_name, retry_cnt, config.MAX_RETRY)
logger.error(err_msg)
raise RuntimeError(err_msg)
result = func(*args, **kwargs)
success = True
except RuntimeError as e:
raise e
except Exception: # pylint: disable=W0703
retry_cnt += 1
import traceback
logger.error(traceback.format_exc())
time.sleep(config.RETRY_DELTA_TIME)
return result
return wrapper
@retry_execute
def check_file_exists_in_obs(obs_path):
""" Detect that file exists in obs. """
bucket_name, object_key = get_bucket_and_key(obs_path)
resp = obsClient.getObjectMetadata(bucket_name, object_key)
if resp.status < 300:
logger.debug("[{} FUNCTION] OBS requestId: {}.".format(
sys._getframe(), resp.requestId)) # pylint: disable=W0212
else:
err_msg = "File {} not found in OBS, please check again.".format(obs_path)
logger.error(err_msg)
raise RuntimeError(err_msg)
@retry_execute
def file_download_from_obs(obs_path, local_path):
""" Download file from OBS. """
bucket_name, object_key = get_bucket_and_key(obs_path)
downloadFile = local_path
taskNum = config.TASK_NUM
partSize = config.PART_SIZE
enableCheckpoint = True
resp = obsClient.downloadFile(
bucket_name, object_key, downloadFile, partSize, taskNum, enableCheckpoint)
if resp.status < 300:
logger.debug("[{} FUNCTION] OBS requestId: {}.".format(
sys._getframe(), resp.requestId)) # pylint: disable=W0212
else:
raise Exception("OBS SDK errorCode:{}, errMsg: {}.".format(
resp.errorCode, resp.errorMessage))
@exclusive_lock
def download_file(remote_path, object_name, des_path, lock_file='tmp'):
""" Download file from OBS exclusively. """
local_path = os.path.join(des_path, object_name)
if os.path.exists(local_path):
return
if not os.path.exists(des_path):
os.makedirs(des_path)
obs_path = os.path.join(remote_path, object_name)
check_file_exists_in_obs(obs_path)
file_download_from_obs(obs_path, local_path)
@exclusive_lock
def init_cache_and_queue(cache, q, path, shard_file, idx, is_full_dataset, lock_file='tmp'):
""" Initialize cache and queue according to the status of local dataset files."""
dataset_file = os.path.basename(shard_file)
if os.path.exists(path): # found in local
logger.info("[{} FUNCTION] Push dataset file {} to cache.".format(
sys._getframe(), dataset_file)) # pylint: disable=W0212
cache[dataset_file] = (idx, not is_full_dataset)
else:
logger.info("[{} FUNCTION] Push dataset file {} to downloading queue.".format(
sys._getframe(), dataset_file)) # pylint: disable=W0212
cache[dataset_file] = (-1, not is_full_dataset)
q.put((idx, shard_file))
@exclusive_lock
def detect_file_exist(local_path, meta_file, lock_file='tmp'):
""" Detect that local dataset file exists or not. """
if os.path.exists(os.path.join(local_path, meta_file)):
return True
return False
@retry_execute
def file_upload_to_obs(obs_path, sync_dir, ready_file_name):
""" Upload sync file to OBS. """
bucket_name, object_key = get_bucket_and_key(obs_path)
if not object_key:
resp = obsClient.headBucket(bucket_name)
else:
if not object_key.endswith("/"):
object_key += "/"
resp = obsClient.getObjectMetadata(bucket_name, object_key)
if resp.status < 300:
logger.debug("[{} FUNCTION] OBS requestId: {}.".format(
sys._getframe(), resp.requestId)) # pylint: disable=W0212
else:
raise RuntimeError("Directory/Bucket used for synchronization {} is not found in OBS, " \
"please create it on OBS first.".format(obs_path))
remote_dir = os.path.join(object_key, sync_dir)
resp = obsClient.putContent(bucket_name, remote_dir, content=None)
if resp.status < 300:
logger.debug("[{} FUNCTION] OBS requestId: {}.".format(
sys._getframe(), resp.requestId)) # pylint: disable=W0212
else:
raise Exception("OBS SDK errorCode:{}, errMsg: {}.".format(
resp.errorCode, resp.errorMessage))
resp = obsClient.putContent(bucket_name, os.path.join(
remote_dir, ready_file_name), content='OK')
if resp.status < 300:
logger.debug("[{} FUNCTION] OBS requestId: {}.".format(
sys._getframe(), resp.requestId)) # pylint: disable=W0212
else:
raise Exception("OBS SDK errorCode:{}, errMsg: {}.".format(
resp.errorCode, resp.errorMessage))

View File

@ -2814,3 +2814,36 @@ def check_multi30k_dataset(method):
return method(self, *args, **kwargs)
return new_method
def check_obsminddataset(method):
"""A wrapper that wraps a parameter checker around the original Dataset(OBSMindDataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
_, param_dict = parse_user_args(method, *args, **kwargs)
nreq_param_int = ['num_shards', 'shard_id']
nreq_param_list = ['column_list']
nreq_param_bool = ['shard_equal_rows']
nreq_param_str = ['server', 'ak', 'sk', 'sync_obs_path']
dataset_files = param_dict.get('dataset_files')
type_check(dataset_files, (list,), "dataset files")
for dataset_file in dataset_files:
if not isinstance(dataset_file, str):
raise TypeError("Item of dataset files is not of type [{}], but got {}.".format(type(''),
type(dataset_file)))
validate_dataset_param_value(nreq_param_int, param_dict, int)
validate_dataset_param_value(nreq_param_list, param_dict, list)
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
validate_dataset_param_value(nreq_param_str, param_dict, str)
server = param_dict.get('server')
if not server.startswith(('http://', 'https://')):
raise ValueError("server should be a str that starts with http:// or https://, but got {}.".format(server))
check_sampler_shuffle_shard_options(param_dict)
return method(self, *args, **kwargs)
return new_method

View File

@ -0,0 +1,80 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test OBSMindDataset operator
"""
import pytest
from mindspore.dataset.engine.datasets_standard_format import OBSMindDataset
from mindspore import log as logger
DATA_DIR = ["s3://dataset/imagenet0", "s3://dataset/imagenet1"]
def test_obs_mindrecord_exception():
"""
Feature: Test OBSMindDataset.
Description: invalid input.
Expectation: raise exception.
"""
logger.info("Test error cases for MnistDataset")
error_msg_0 = "Argument dataset files"
with pytest.raises(TypeError, match=error_msg_0):
OBSMindDataset("err_dataset", "https://dummy_site", "dummy_ak", "dummy_sk", "s3://dummy_sync_dir")
error_msg_0_1 = "Item of dataset files"
with pytest.raises(TypeError, match=error_msg_0_1):
OBSMindDataset([1, 2], "https://dummy_site", "dummy_ak", "dummy_sk", "s3://dummy_sync_dir")
error_msg_1 = "Argument server"
with pytest.raises(TypeError, match=error_msg_1):
OBSMindDataset(DATA_DIR, 12, "dummy_ak", "dummy_sk", "s3://dummy_sync_dir")
error_msg_1_1 = "server should"
with pytest.raises(ValueError, match=error_msg_1_1):
OBSMindDataset(DATA_DIR, "ftp://dummy_site", "dummy_ak", "dummy_sk", "s3://dummy_sync_dir")
error_msg_2 = "Argument ak"
with pytest.raises(TypeError, match=error_msg_2):
OBSMindDataset(DATA_DIR, "https://dummy_site", 12, "dummy_sk", "s3://dummy_sync_dir")
error_msg_3 = "Argument sk"
with pytest.raises(TypeError, match=error_msg_3):
OBSMindDataset(DATA_DIR, "https://dummy_site", "dummy_ak", 12, "s3://dummy_sync_dir")
error_msg_4 = "Argument sync_obs_path"
with pytest.raises(TypeError, match=error_msg_4):
OBSMindDataset(DATA_DIR, "https://dummy_site", "dummy_ak", "dummy_sk", 12)
error_msg_5 = "Input shard_id is not within the required interval"
with pytest.raises(ValueError, match=error_msg_5):
OBSMindDataset(DATA_DIR, "https://dummy_site", "dummy_ak",
"dummy_sk", "s3://dummy_sync_dir", num_shards=2, shard_id=-1)
with pytest.raises(ValueError, match=error_msg_5):
OBSMindDataset(DATA_DIR, "https://dummy_site", "dummy_ak",
"dummy_sk", "s3://dummy_sync_dir", num_shards=4, shard_id=4)
with pytest.raises(ValueError, match=error_msg_5):
OBSMindDataset(DATA_DIR, "https://dummy_site", "dummy_ak",
"dummy_sk", "s3://dummy_sync_dir", num_shards=2, shard_id=4)
error_msg_7 = "Argument shard_equal_rows"
with pytest.raises(TypeError, match=error_msg_7):
OBSMindDataset(DATA_DIR, "https://dummy_site", "dummy_ak",
"dummy_sk", "s3://dummy_sync_dir", shard_equal_rows=1)
if __name__ == '__main__':
test_obs_mindrecord_exception()