forked from mindspore-Ecosystem/mindspore
!19765 Fix docs issues
Merge pull request !19765 from luoyang/code_docs_python-doc
This commit is contained in:
commit
f86644358e
|
@ -117,7 +117,6 @@ class BertTokenizer final : public TensorTransform {
|
||||||
};
|
};
|
||||||
|
|
||||||
/// \brief Apply case fold operation on UTF-8 string tensors.
|
/// \brief Apply case fold operation on UTF-8 string tensors.
|
||||||
/// \return Shared pointer to the current TensorOperation.
|
|
||||||
class CaseFold final : public TensorTransform {
|
class CaseFold final : public TensorTransform {
|
||||||
public:
|
public:
|
||||||
/// \brief Constructor.
|
/// \brief Constructor.
|
||||||
|
@ -142,7 +141,8 @@ class JiebaTokenizer final : public TensorTransform {
|
||||||
/// official website of cppjieba (https://github.com/yanyiwu/cppjieba).
|
/// official website of cppjieba (https://github.com/yanyiwu/cppjieba).
|
||||||
/// \param[in] mp_path Dictionary file is used by the MPSegment algorithm. The dictionary can be obtained on the
|
/// \param[in] mp_path Dictionary file is used by the MPSegment algorithm. The dictionary can be obtained on the
|
||||||
/// official website of cppjieba (https://github.com/yanyiwu/cppjieba).
|
/// official website of cppjieba (https://github.com/yanyiwu/cppjieba).
|
||||||
/// \param[in] mode Valid values can be any of JiebaMode.MP, JiebaMode.HMM and JiebaMode.MIX (default=JiebaMode.MIX).
|
/// \param[in] mode Valid values can be any of JiebaMode.kMP, JiebaMode.kHMM and JiebaMode.kMIX
|
||||||
|
/// (default=JiebaMode.kMIX).
|
||||||
/// - JiebaMode.kMP, tokenizes with MPSegment algorithm.
|
/// - JiebaMode.kMP, tokenizes with MPSegment algorithm.
|
||||||
/// - JiebaMode.kHMM, tokenizes with Hidden Markov Model Segment algorithm.
|
/// - JiebaMode.kHMM, tokenizes with Hidden Markov Model Segment algorithm.
|
||||||
/// - JiebaMode.kMIX, tokenizes with a mix of MPSegment and HMMSegment algorithms.
|
/// - JiebaMode.kMIX, tokenizes with a mix of MPSegment and HMMSegment algorithms.
|
||||||
|
@ -248,7 +248,7 @@ class Ngram final : public TensorTransform {
|
||||||
/// \param[in] left_pad {"pad_token", pad_width}. Padding performed on left side of the sequence. pad_width will
|
/// \param[in] left_pad {"pad_token", pad_width}. Padding performed on left side of the sequence. pad_width will
|
||||||
/// be capped at n-1. left_pad=("_",2) would pad the left side of the sequence with "__" (default={"", 0}}).
|
/// be capped at n-1. left_pad=("_",2) would pad the left side of the sequence with "__" (default={"", 0}}).
|
||||||
/// \param[in] right_pad {"pad_token", pad_width}. Padding performed on right side of the sequence.pad_width will
|
/// \param[in] right_pad {"pad_token", pad_width}. Padding performed on right side of the sequence.pad_width will
|
||||||
/// be capped at n-1. right_pad=("-":2) would pad the right side of the sequence with "--" (default={"", 0}}).
|
/// be capped at n-1. right_pad=("-",2) would pad the right side of the sequence with "--" (default={"", 0}}).
|
||||||
/// \param[in] separator Symbol used to join strings together (default=" ").
|
/// \param[in] separator Symbol used to join strings together (default=" ").
|
||||||
explicit Ngram(const std::vector<int32_t> &ngrams, const std::pair<std::string, int32_t> &left_pad = {"", 0},
|
explicit Ngram(const std::vector<int32_t> &ngrams, const std::pair<std::string, int32_t> &left_pad = {"", 0},
|
||||||
const std::pair<std::string, int32_t> &right_pad = {"", 0}, const std::string &separator = " ")
|
const std::pair<std::string, int32_t> &right_pad = {"", 0}, const std::string &separator = " ")
|
||||||
|
@ -276,14 +276,13 @@ class NormalizeUTF8 final : public TensorTransform {
|
||||||
public:
|
public:
|
||||||
/// \brief Constructor.
|
/// \brief Constructor.
|
||||||
/// \param[in] normalize_form Valid values can be any of [NormalizeForm::kNone,NormalizeForm::kNfc,
|
/// \param[in] normalize_form Valid values can be any of [NormalizeForm::kNone,NormalizeForm::kNfc,
|
||||||
/// NormalizeForm::kNfkc,
|
/// NormalizeForm::kNfkc, NormalizeForm::kNfd, NormalizeForm::kNfkd](default=NormalizeForm::kNfkc).
|
||||||
/// NormalizeForm::kNfd, NormalizeForm::kNfkd](default=NormalizeForm::kNfkc).
|
|
||||||
/// See http://unicode.org/reports/tr15/ for details.
|
/// See http://unicode.org/reports/tr15/ for details.
|
||||||
/// - NormalizeForm.NONE, remain the input string tensor unchanged.
|
/// - NormalizeForm.kNone, remain the input string tensor unchanged.
|
||||||
/// - NormalizeForm.NFC, normalizes with Normalization Form C.
|
/// - NormalizeForm.kNfc, normalizes with Normalization Form C.
|
||||||
/// - NormalizeForm.NFKC, normalizes with Normalization Form KC.
|
/// - NormalizeForm.kNfkc, normalizes with Normalization Form KC.
|
||||||
/// - NormalizeForm.NFD, normalizes with Normalization Form D.
|
/// - NormalizeForm.kNfd, normalizes with Normalization Form D.
|
||||||
/// - NormalizeForm.NFKD, normalizes with Normalization Form KD.
|
/// - NormalizeForm.kNfkd, normalizes with Normalization Form KD.
|
||||||
explicit NormalizeUTF8(NormalizeForm normalize_form = NormalizeForm::kNfkc);
|
explicit NormalizeUTF8(NormalizeForm normalize_form = NormalizeForm::kNfkc);
|
||||||
|
|
||||||
/// \brief Destructor
|
/// \brief Destructor
|
||||||
|
|
|
@ -79,7 +79,6 @@ def set_seed(seed):
|
||||||
If the seed is set, the generated random number will be fixed, this helps to
|
If the seed is set, the generated random number will be fixed, this helps to
|
||||||
produce deterministic results.
|
produce deterministic results.
|
||||||
|
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
This set_seed function sets the seed in the Python random library and numpy.random library
|
This set_seed function sets the seed in the Python random library and numpy.random library
|
||||||
for deterministic Python augmentations using randomness. This set_seed function should
|
for deterministic Python augmentations using randomness. This set_seed function should
|
||||||
|
@ -113,6 +112,11 @@ def get_seed():
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
int, random number seed.
|
int, random number seed.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> # Get the global configuration of seed.
|
||||||
|
>>> # If set_seed() is never called before, the default value(std::mt19937::default_seed) will be returned.
|
||||||
|
>>> seed = ds.config.get_seed()
|
||||||
"""
|
"""
|
||||||
return _config.get_seed()
|
return _config.get_seed()
|
||||||
|
|
||||||
|
@ -147,6 +151,11 @@ def get_prefetch_size():
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
int, total number of rows to be prefetched.
|
int, total number of rows to be prefetched.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> # Get the global configuration of prefetch size.
|
||||||
|
>>> # If set_prefetch_size() is never called before, the default value(20) will be returned.
|
||||||
|
>>> prefetch_size = ds.config.get_prefetch_size()
|
||||||
"""
|
"""
|
||||||
return _config.get_op_connector_size()
|
return _config.get_op_connector_size()
|
||||||
|
|
||||||
|
@ -174,12 +183,17 @@ def set_num_parallel_workers(num):
|
||||||
|
|
||||||
def get_num_parallel_workers():
|
def get_num_parallel_workers():
|
||||||
"""
|
"""
|
||||||
Get the default number of parallel workers.
|
Get the global configuration of number of parallel workers.
|
||||||
This is the DEFAULT num_parallel_workers value used for each operation, it is not related
|
This is the DEFAULT num_parallel_workers value used for each operation, it is not related
|
||||||
to AutoNumWorker feature.
|
to AutoNumWorker feature.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
int, number of parallel workers to be used as a default for each operation.
|
int, number of parallel workers to be used as a default for each operation.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> # Get the global configuration of parallel workers.
|
||||||
|
>>> # If set_num_parallel_workers() is never called before, the default value(8) will be returned.
|
||||||
|
>>> num_parallel_workers = ds.config.get_num_parallel_workers()
|
||||||
"""
|
"""
|
||||||
return _config.get_num_parallel_workers()
|
return _config.get_num_parallel_workers()
|
||||||
|
|
||||||
|
@ -206,11 +220,15 @@ def set_numa_enable(numa_enable):
|
||||||
|
|
||||||
def get_numa_enable():
|
def get_numa_enable():
|
||||||
"""
|
"""
|
||||||
Get the default state of numa enabled.
|
Get the state of numa to indicate enabled/disabled.
|
||||||
This is the DEFAULT numa enabled value used for the all process.
|
This is the DEFAULT numa enabled value used for the all process.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool, the default state of numa enabled.
|
bool, the default state of numa enabled.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> # Get the global configuration of numa.
|
||||||
|
>>> numa_state = ds.config.get_numa_enable()
|
||||||
"""
|
"""
|
||||||
return _config.get_numa_enable()
|
return _config.get_numa_enable()
|
||||||
|
|
||||||
|
@ -236,10 +254,15 @@ def set_monitor_sampling_interval(interval):
|
||||||
|
|
||||||
def get_monitor_sampling_interval():
|
def get_monitor_sampling_interval():
|
||||||
"""
|
"""
|
||||||
Get the default interval of performance monitor sampling.
|
Get the global configuration of sampling interval of performance monitor.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
int, interval (in milliseconds) for performance monitor sampling.
|
int, interval (in milliseconds) for performance monitor sampling.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> # Get the global configuration of monitor sampling interval.
|
||||||
|
>>> # If set_monitor_sampling_interval() is never called before, the default value(1000) will be returned.
|
||||||
|
>>> ds.config.get_monitor_sampling_interval()
|
||||||
"""
|
"""
|
||||||
return _config.get_monitor_sampling_interval()
|
return _config.get_monitor_sampling_interval()
|
||||||
|
|
||||||
|
@ -299,9 +322,10 @@ def get_auto_num_workers():
|
||||||
Get the setting (turned on or off) automatic number of workers.
|
Get the setting (turned on or off) automatic number of workers.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool, whether auto num worker feature is turned on.
|
bool, whether auto number worker feature is turned on.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
>>> # Get the global configuration of auto number worker feature.
|
||||||
>>> num_workers = ds.config.get_auto_num_workers()
|
>>> num_workers = ds.config.get_auto_num_workers()
|
||||||
"""
|
"""
|
||||||
return _config.get_auto_num_workers()
|
return _config.get_auto_num_workers()
|
||||||
|
@ -334,6 +358,11 @@ def get_callback_timeout():
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
int, Timeout (in seconds) to be used to end the wait in DSWaitedCallback in case of a deadlock.
|
int, Timeout (in seconds) to be used to end the wait in DSWaitedCallback in case of a deadlock.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> # Get the global configuration of callback timeout.
|
||||||
|
>>> # If set_callback_timeout() is never called before, the default value(60) will be returned.
|
||||||
|
>>> ds.config.get_callback_timeout()
|
||||||
"""
|
"""
|
||||||
return _config.get_callback_timeout()
|
return _config.get_callback_timeout()
|
||||||
|
|
||||||
|
@ -394,6 +423,10 @@ def get_enable_shared_mem():
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool, the state of shared mem enabled variable (default=True).
|
bool, the state of shared mem enabled variable (default=True).
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> # Get the flag of shared memory feature.
|
||||||
|
>>> shared_mem_flag = ds.config.get_enable_shared_mem()
|
||||||
"""
|
"""
|
||||||
return _config.get_enable_shared_mem()
|
return _config.get_enable_shared_mem()
|
||||||
|
|
||||||
|
@ -410,12 +443,14 @@ def set_enable_shared_mem(enable):
|
||||||
TypeError: If enable is not a boolean data type.
|
TypeError: If enable is not a boolean data type.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
>>> # Enable shared memory feature to improve the performance of Python multiprocessing.
|
||||||
>>> ds.config.set_enable_shared_mem(True)
|
>>> ds.config.set_enable_shared_mem(True)
|
||||||
"""
|
"""
|
||||||
if not isinstance(enable, bool):
|
if not isinstance(enable, bool):
|
||||||
raise TypeError("enable must be of type bool.")
|
raise TypeError("enable must be of type bool.")
|
||||||
_config.set_enable_shared_mem(enable)
|
_config.set_enable_shared_mem(enable)
|
||||||
|
|
||||||
|
|
||||||
def set_sending_batches(batch_num):
|
def set_sending_batches(batch_num):
|
||||||
"""
|
"""
|
||||||
Set the default sending batches when training with sink_mode=True in Ascend device.
|
Set the default sending batches when training with sink_mode=True in Ascend device.
|
||||||
|
|
|
@ -334,7 +334,7 @@ class Dataset:
|
||||||
Serialize a pipeline into JSON string and dump into file if filename is provided.
|
Serialize a pipeline into JSON string and dump into file if filename is provided.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
filename (str): filename of json file to be saved as
|
filename (str): filename of JSON file to be saved as.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str, JSON string of the pipeline.
|
str, JSON string of the pipeline.
|
||||||
|
@ -1511,7 +1511,7 @@ class Dataset:
|
||||||
|
|
||||||
def get_col_names(self):
|
def get_col_names(self):
|
||||||
"""
|
"""
|
||||||
Renturn the names of the columns in dataset.
|
Return the names of the columns in dataset.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list, list of column names in the dataset.
|
list, list of column names in the dataset.
|
||||||
|
@ -1582,7 +1582,7 @@ class Dataset:
|
||||||
|
|
||||||
def dynamic_min_max_shapes(self):
|
def dynamic_min_max_shapes(self):
|
||||||
"""
|
"""
|
||||||
Get minimum and maximum data length of dynamic source data, for graph compilation of ME.
|
Get minimum and maximum data length of dynamic source data, for dynamic graph compilation.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
lists, min_shapes, max_shapes of source data.
|
lists, min_shapes, max_shapes of source data.
|
||||||
|
@ -2187,7 +2187,7 @@ class BatchDataset(Dataset):
|
||||||
self.per_batch_map = _PythonCallable(self.per_batch_map, idx, self.process_pool, arg_q_list, res_q_list)
|
self.per_batch_map = _PythonCallable(self.per_batch_map, idx, self.process_pool, arg_q_list, res_q_list)
|
||||||
self.hook = _ExceptHookHandler()
|
self.hook = _ExceptHookHandler()
|
||||||
atexit.register(_mp_pool_exit_preprocess)
|
atexit.register(_mp_pool_exit_preprocess)
|
||||||
# If python version greater than 3.8, we need to close ThreadPool in atexit for unclean pool teardown.
|
# If Python version greater than 3.8, we need to close ThreadPool in atexit for unclean pool teardown.
|
||||||
if sys.version_info >= (3, 8):
|
if sys.version_info >= (3, 8):
|
||||||
atexit.register(self.process_pool.close)
|
atexit.register(self.process_pool.close)
|
||||||
else:
|
else:
|
||||||
|
@ -2682,7 +2682,7 @@ class MapDataset(Dataset):
|
||||||
self.operations = iter_specific_operations
|
self.operations = iter_specific_operations
|
||||||
self.hook = _ExceptHookHandler()
|
self.hook = _ExceptHookHandler()
|
||||||
atexit.register(_mp_pool_exit_preprocess)
|
atexit.register(_mp_pool_exit_preprocess)
|
||||||
# If python version greater than 3.8, we need to close ThreadPool in atexit for unclean pool teardown.
|
# If Python version greater than 3.8, we need to close ThreadPool in atexit for unclean pool teardown.
|
||||||
if sys.version_info >= (3, 8):
|
if sys.version_info >= (3, 8):
|
||||||
atexit.register(self.process_pool.close)
|
atexit.register(self.process_pool.close)
|
||||||
|
|
||||||
|
@ -4798,12 +4798,12 @@ class VOCDataset(MappableDataset):
|
||||||
title = {The Pascal Visual Object Classes (VOC) Challenge},
|
title = {The Pascal Visual Object Classes (VOC) Challenge},
|
||||||
journal = {International Journal of Computer Vision},
|
journal = {International Journal of Computer Vision},
|
||||||
volume = {88},
|
volume = {88},
|
||||||
year = {2010},
|
year = {2012},
|
||||||
number = {2},
|
number = {2},
|
||||||
month = {jun},
|
month = {jun},
|
||||||
pages = {303--338},
|
pages = {303--338},
|
||||||
biburl = {http://host.robots.ox.ac.uk/pascal/VOC/pubs/everingham10.html#bibtex},
|
biburl = {http://host.robots.ox.ac.uk/pascal/VOC/pubs/everingham10.html#bibtex},
|
||||||
howpublished = {http://host.robots.ox.ac.uk/pascal/VOC/voc{year}/index.html}
|
howpublished = {http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html}
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -4959,10 +4959,11 @@ class CocoDataset(MappableDataset):
|
||||||
|
|
||||||
About COCO dataset:
|
About COCO dataset:
|
||||||
|
|
||||||
COCO is a large-scale object detection, segmentation, and captioning dataset.
|
COCO(Microsoft Common Objects in Context) is a large-scale object detection, segmentation, and captioning dataset
|
||||||
It contains 91 common object categories with 82 of them having more than 5,000
|
with several features: Object segmentation, Recognition in context, Superpixel stuff segmentation,
|
||||||
labeled instances. In contrast to the popular ImageNet dataset, COCO has fewer
|
330K images (>200K labeled), 1.5 million object instances, 80 object categories, 91 stuff categories,
|
||||||
categories but more instances per category.
|
5 captions per image, 250,000 people with keypoints. In contrast to the popular ImageNet dataset, COCO has fewer
|
||||||
|
categories but more instances in per category.
|
||||||
|
|
||||||
You can unzip the original COCO-2017 dataset files into this directory structure and read by MindSpore's API.
|
You can unzip the original COCO-2017 dataset files into this directory structure and read by MindSpore's API.
|
||||||
|
|
||||||
|
@ -5304,7 +5305,7 @@ class CLUEDataset(SourceDataset):
|
||||||
|
|
||||||
About CLUE dataset:
|
About CLUE dataset:
|
||||||
|
|
||||||
CLUE, a Chinese Language Understanding Evaluation benchmark. It contains eight different
|
CLUE, a Chinese Language Understanding Evaluation benchmark. It contains multiple
|
||||||
tasks, including single-sentence classification, sentence pair classification, and machine
|
tasks, including single-sentence classification, sentence pair classification, and machine
|
||||||
reading comprehension.
|
reading comprehension.
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""
|
"""
|
||||||
This dataset module creates an internal queue class to more optimally pass data
|
This dataset module creates an internal queue class to more optimally pass data
|
||||||
between multiple processes in python. It has same API as multiprocessing.queue
|
between multiple processes in Python. It has same API as multiprocessing.queue
|
||||||
but it will pass large data through shared memory.
|
but it will pass large data through shared memory.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
|
@ -121,7 +121,7 @@ class BuiltinSampler:
|
||||||
self.child_sampler = sampler
|
self.child_sampler = sampler
|
||||||
|
|
||||||
def get_child(self):
|
def get_child(self):
|
||||||
""" add a child sampler. """
|
""" Get the child sampler. """
|
||||||
return self.child_sampler
|
return self.child_sampler
|
||||||
|
|
||||||
def parse_child(self):
|
def parse_child(self):
|
||||||
|
@ -139,11 +139,11 @@ class BuiltinSampler:
|
||||||
return c_child_sampler
|
return c_child_sampler
|
||||||
|
|
||||||
def is_shuffled(self):
|
def is_shuffled(self):
|
||||||
""" not implemented """
|
""" Not implemented. """
|
||||||
raise NotImplementedError("Sampler must implement is_shuffled.")
|
raise NotImplementedError("Sampler must implement is_shuffled.")
|
||||||
|
|
||||||
def is_sharded(self):
|
def is_sharded(self):
|
||||||
""" not implemented """
|
""" Not implemented. """
|
||||||
raise NotImplementedError("Sampler must implement is_sharded.")
|
raise NotImplementedError("Sampler must implement is_sharded.")
|
||||||
|
|
||||||
def get_num_samples(self):
|
def get_num_samples(self):
|
||||||
|
@ -313,8 +313,10 @@ class DistributedSampler(BuiltinSampler):
|
||||||
shard_id (int): Shard ID of the current shard, which should within the range of [0, num_shards-1].
|
shard_id (int): Shard ID of the current shard, which should within the range of [0, num_shards-1].
|
||||||
shuffle (bool, optional): If True, the indices are shuffled, otherwise it will not be shuffled(default=True).
|
shuffle (bool, optional): If True, the indices are shuffled, otherwise it will not be shuffled(default=True).
|
||||||
num_samples (int, optional): The number of samples to draw (default=None, which means sample all elements).
|
num_samples (int, optional): The number of samples to draw (default=None, which means sample all elements).
|
||||||
offset(int, optional): The starting shard ID where the elements in the dataset are sent to (default=-1), which
|
offset(int, optional): The starting shard ID where the elements in the dataset are sent to, which
|
||||||
should be no more than num_shards.
|
should be no more than num_shards. This parameter is only valid when a ConcatDataset takes
|
||||||
|
a DistributedSampler as its sampler. It will affect the number of samples of per shard
|
||||||
|
(default=-1, which means each shard has same number of samples).
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> # creates a distributed sampler with 10 shards in total. This shard is shard 5.
|
>>> # creates a distributed sampler with 10 shards in total. This shard is shard 5.
|
||||||
|
@ -329,9 +331,9 @@ class DistributedSampler(BuiltinSampler):
|
||||||
TypeError: If shuffle is not a boolean value.
|
TypeError: If shuffle is not a boolean value.
|
||||||
TypeError: If num_samples is not an integer value.
|
TypeError: If num_samples is not an integer value.
|
||||||
TypeError: If offset is not an integer value.
|
TypeError: If offset is not an integer value.
|
||||||
|
ValueError: If num_samples is a negative value.
|
||||||
RuntimeError: If num_shards is not a positive value.
|
RuntimeError: If num_shards is not a positive value.
|
||||||
RuntimeError: If shard_id is smaller than 0 or equal to num_shards or larger than num_shards.
|
RuntimeError: If shard_id is smaller than 0 or equal to num_shards or larger than num_shards.
|
||||||
RuntimeError: If num_samples is a negative value.
|
|
||||||
RuntimeError: If offset is greater than num_shards.
|
RuntimeError: If offset is greater than num_shards.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -423,13 +425,12 @@ class PKSampler(BuiltinSampler):
|
||||||
... sampler=sampler)
|
... sampler=sampler)
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: If num_val is not a positive value.
|
|
||||||
TypeError: If shuffle is not a boolean value.
|
TypeError: If shuffle is not a boolean value.
|
||||||
TypeError: If class_column is not a str value.
|
TypeError: If class_column is not a str value.
|
||||||
TypeError: If num_samples is not an integer value.
|
TypeError: If num_samples is not an integer value.
|
||||||
NotImplementedError: If num_class is not None.
|
NotImplementedError: If num_class is not None.
|
||||||
RuntimeError: If num_val is not a positive value.
|
RuntimeError: If num_val is not a positive value.
|
||||||
RuntimeError: If num_samples is a negative value.
|
ValueError: If num_samples is a negative value.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, num_val, num_class=None, shuffle=False, class_column='label', num_samples=None):
|
def __init__(self, num_val, num_class=None, shuffle=False, class_column='label', num_samples=None):
|
||||||
|
@ -508,7 +509,7 @@ class RandomSampler(BuiltinSampler):
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: If replacement is not a boolean value.
|
TypeError: If replacement is not a boolean value.
|
||||||
TypeError: If num_samples is not an integer value.
|
TypeError: If num_samples is not an integer value.
|
||||||
RuntimeError: If num_samples is a negative value.
|
ValueError: If num_samples is a negative value.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, replacement=False, num_samples=None):
|
def __init__(self, replacement=False, num_samples=None):
|
||||||
|
@ -573,7 +574,7 @@ class SequentialSampler(BuiltinSampler):
|
||||||
TypeError: If start_index is not an integer value.
|
TypeError: If start_index is not an integer value.
|
||||||
TypeError: If num_samples is not an integer value.
|
TypeError: If num_samples is not an integer value.
|
||||||
RuntimeError: If start_index is a negative value.
|
RuntimeError: If start_index is a negative value.
|
||||||
RuntimeError: If num_samples is a negative value.
|
ValueError: If num_samples is a negative value.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, start_index=None, num_samples=None):
|
def __init__(self, start_index=None, num_samples=None):
|
||||||
|
@ -641,7 +642,7 @@ class SubsetSampler(BuiltinSampler):
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: If type of indices element is not a number.
|
TypeError: If type of indices element is not a number.
|
||||||
TypeError: If num_samples is not an integer value.
|
TypeError: If num_samples is not an integer value.
|
||||||
RuntimeError: If num_samples is a negative value.
|
ValueError: If num_samples is a negative value.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, indices, num_samples=None):
|
def __init__(self, indices, num_samples=None):
|
||||||
|
@ -713,7 +714,7 @@ class SubsetRandomSampler(SubsetSampler):
|
||||||
Samples the elements randomly from a sequence of indices.
|
Samples the elements randomly from a sequence of indices.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
indices (Any iterable python object but string): A sequence of indices.
|
indices (Any iterable Python object but string): A sequence of indices.
|
||||||
num_samples (int, optional): Number of elements to sample (default=None, which means sample all elements).
|
num_samples (int, optional): Number of elements to sample (default=None, which means sample all elements).
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
@ -726,7 +727,7 @@ class SubsetRandomSampler(SubsetSampler):
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: If type of indices element is not a number.
|
TypeError: If type of indices element is not a number.
|
||||||
TypeError: If num_samples is not an integer value.
|
TypeError: If num_samples is not an integer value.
|
||||||
RuntimeError: If num_samples is a negative value.
|
ValueError: If num_samples is a negative value.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def parse(self):
|
def parse(self):
|
||||||
|
@ -806,7 +807,7 @@ class WeightedRandomSampler(BuiltinSampler):
|
||||||
TypeError: If num_samples is not an integer value.
|
TypeError: If num_samples is not an integer value.
|
||||||
TypeError: If replacement is not a boolean value.
|
TypeError: If replacement is not a boolean value.
|
||||||
RuntimeError: If weights is empty or all zero.
|
RuntimeError: If weights is empty or all zero.
|
||||||
RuntimeError: If num_samples is a negative value.
|
ValueError: If num_samples is a negative value.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, weights, num_samples=None, replacement=True):
|
def __init__(self, weights, num_samples=None, replacement=True):
|
||||||
|
|
|
@ -27,15 +27,15 @@ from ..vision.utils import Inter, Border, ImageBatchFormat
|
||||||
|
|
||||||
def serialize(dataset, json_filepath=""):
|
def serialize(dataset, json_filepath=""):
|
||||||
"""
|
"""
|
||||||
Serialize dataset pipeline into a json file.
|
Serialize dataset pipeline into a JSON file.
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
Currently some python objects are not supported to be serialized.
|
Currently some Python objects are not supported to be serialized.
|
||||||
For python function serialization of map operator, de.serialize will only return its function name.
|
For Python function serialization of map operator, de.serialize will only return its function name.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset (Dataset): The starting node.
|
dataset (Dataset): The starting node.
|
||||||
json_filepath (str): The filepath where a serialized json file will be generated.
|
json_filepath (str): The filepath where a serialized JSON file will be generated.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict, The dictionary contains the serialized dataset graph.
|
Dict, The dictionary contains the serialized dataset graph.
|
||||||
|
@ -48,7 +48,7 @@ def serialize(dataset, json_filepath=""):
|
||||||
>>> one_hot_encode = c_transforms.OneHot(10) # num_classes is input argument
|
>>> one_hot_encode = c_transforms.OneHot(10) # num_classes is input argument
|
||||||
>>> dataset = dataset.map(operation=one_hot_encode, input_column_names="label")
|
>>> dataset = dataset.map(operation=one_hot_encode, input_column_names="label")
|
||||||
>>> dataset = dataset.batch(batch_size=10, drop_remainder=True)
|
>>> dataset = dataset.batch(batch_size=10, drop_remainder=True)
|
||||||
>>> # serialize it to json file
|
>>> # serialize it to JSON file
|
||||||
>>> ds.engine.serialize(dataset, json_filepath="/path/to/mnist_dataset_pipeline.json")
|
>>> ds.engine.serialize(dataset, json_filepath="/path/to/mnist_dataset_pipeline.json")
|
||||||
>>> serialized_data = ds.engine.serialize(dataset) # serialize it to Python dict
|
>>> serialized_data = ds.engine.serialize(dataset) # serialize it to Python dict
|
||||||
"""
|
"""
|
||||||
|
@ -57,27 +57,27 @@ def serialize(dataset, json_filepath=""):
|
||||||
|
|
||||||
def deserialize(input_dict=None, json_filepath=None):
|
def deserialize(input_dict=None, json_filepath=None):
|
||||||
"""
|
"""
|
||||||
Construct a de pipeline from a json file produced by de.serialize().
|
Construct a de pipeline from a JSON file produced by de.serialize().
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
Currently python function deserialization of map operator are not supported.
|
Currently Python function deserialization of map operator are not supported.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_dict (dict): A Python dictionary containing a serialized dataset graph.
|
input_dict (dict): A Python dictionary containing a serialized dataset graph.
|
||||||
json_filepath (str): A path to the json file.
|
json_filepath (str): A path to the JSON file.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
de.Dataset or None if error occurs.
|
de.Dataset or None if error occurs.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
OSError: Can not open the json file.
|
OSError: Can not open the JSON file.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> dataset = ds.MnistDataset(mnist_dataset_dir, 100)
|
>>> dataset = ds.MnistDataset(mnist_dataset_dir, 100)
|
||||||
>>> one_hot_encode = c_transforms.OneHot(10) # num_classes is input argument
|
>>> one_hot_encode = c_transforms.OneHot(10) # num_classes is input argument
|
||||||
>>> dataset = dataset.map(operation=one_hot_encode, input_column_names="label")
|
>>> dataset = dataset.map(operation=one_hot_encode, input_column_names="label")
|
||||||
>>> dataset = dataset.batch(batch_size=10, drop_remainder=True)
|
>>> dataset = dataset.batch(batch_size=10, drop_remainder=True)
|
||||||
>>> # Use case 1: to/from json file
|
>>> # Use case 1: to/from JSON file
|
||||||
>>> ds.engine.serialize(dataset, json_filepath="/path/to/mnist_dataset_pipeline.json")
|
>>> ds.engine.serialize(dataset, json_filepath="/path/to/mnist_dataset_pipeline.json")
|
||||||
>>> dataset = ds.engine.deserialize(json_filepath="/path/to/mnist_dataset_pipeline.json")
|
>>> dataset = ds.engine.deserialize(json_filepath="/path/to/mnist_dataset_pipeline.json")
|
||||||
>>> # Use case 2: to/from Python dictionary
|
>>> # Use case 2: to/from Python dictionary
|
||||||
|
@ -113,8 +113,15 @@ def show(dataset, indentation=2):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset (Dataset): The starting node.
|
dataset (Dataset): The starting node.
|
||||||
indentation (int, optional): The indentation used by the json print.
|
indentation (int, optional): The indentation used by the JSON print.
|
||||||
Do not indent if indentation is None.
|
Do not indent if indentation is None.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> dataset = ds.MnistDataset(mnist_dataset_dir, 100)
|
||||||
|
>>> one_hot_encode = c_transforms.OneHot(10)
|
||||||
|
>>> dataset = dataset.map(operation=one_hot_encode, input_column_names="label")
|
||||||
|
>>> dataset = dataset.batch(batch_size=10, drop_remainder=True)
|
||||||
|
>>> ds.show(dataset)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
pipeline = dataset.to_json()
|
pipeline = dataset.to_json()
|
||||||
|
@ -128,13 +135,21 @@ def compare(pipeline1, pipeline2):
|
||||||
Args:
|
Args:
|
||||||
pipeline1 (Dataset): a dataset pipeline.
|
pipeline1 (Dataset): a dataset pipeline.
|
||||||
pipeline2 (Dataset): a dataset pipeline.
|
pipeline2 (Dataset): a dataset pipeline.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Whether pipeline1 is equal to pipeline2.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> pipeline1 = ds.MnistDataset(mnist_dataset_dir, 100)
|
||||||
|
>>> pipeline2 = ds.Cifar10Dataset(cifar_dataset_dir, 100)
|
||||||
|
>>> ds.compare(pipeline1, pipeline2)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return pipeline1.to_json() == pipeline2.to_json()
|
return pipeline1.to_json() == pipeline2.to_json()
|
||||||
|
|
||||||
|
|
||||||
def construct_pipeline(node):
|
def construct_pipeline(node):
|
||||||
"""Construct the Python Dataset objects by following the dictionary deserialized from json file."""
|
"""Construct the Python Dataset objects by following the dictionary deserialized from JSON file."""
|
||||||
op_type = node.get('op_type')
|
op_type = node.get('op_type')
|
||||||
if not op_type:
|
if not op_type:
|
||||||
raise ValueError("op_type field in the json file can't be None.")
|
raise ValueError("op_type field in the json file can't be None.")
|
||||||
|
|
Loading…
Reference in New Issue