!19765 Fix docs issues

Merge pull request !19765 from luoyang/code_docs_python-doc
This commit is contained in:
i-robot 2021-07-09 07:48:49 +00:00 committed by Gitee
commit f86644358e
6 changed files with 109 additions and 58 deletions

View File

@ -117,7 +117,6 @@ class BertTokenizer final : public TensorTransform {
}; };
/// \brief Apply case fold operation on UTF-8 string tensors. /// \brief Apply case fold operation on UTF-8 string tensors.
/// \return Shared pointer to the current TensorOperation.
class CaseFold final : public TensorTransform { class CaseFold final : public TensorTransform {
public: public:
/// \brief Constructor. /// \brief Constructor.
@ -142,7 +141,8 @@ class JiebaTokenizer final : public TensorTransform {
/// official website of cppjieba (https://github.com/yanyiwu/cppjieba). /// official website of cppjieba (https://github.com/yanyiwu/cppjieba).
/// \param[in] mp_path Dictionary file is used by the MPSegment algorithm. The dictionary can be obtained on the /// \param[in] mp_path Dictionary file is used by the MPSegment algorithm. The dictionary can be obtained on the
/// official website of cppjieba (https://github.com/yanyiwu/cppjieba). /// official website of cppjieba (https://github.com/yanyiwu/cppjieba).
/// \param[in] mode Valid values can be any of JiebaMode.MP, JiebaMode.HMM and JiebaMode.MIX (default=JiebaMode.MIX). /// \param[in] mode Valid values can be any of JiebaMode.kMP, JiebaMode.kHMM and JiebaMode.kMIX
/// (default=JiebaMode.kMIX).
/// - JiebaMode.kMP, tokenizes with MPSegment algorithm. /// - JiebaMode.kMP, tokenizes with MPSegment algorithm.
/// - JiebaMode.kHMM, tokenizes with Hidden Markov Model Segment algorithm. /// - JiebaMode.kHMM, tokenizes with Hidden Markov Model Segment algorithm.
/// - JiebaMode.kMIX, tokenizes with a mix of MPSegment and HMMSegment algorithms. /// - JiebaMode.kMIX, tokenizes with a mix of MPSegment and HMMSegment algorithms.
@ -248,7 +248,7 @@ class Ngram final : public TensorTransform {
/// \param[in] left_pad {"pad_token", pad_width}. Padding performed on left side of the sequence. pad_width will /// \param[in] left_pad {"pad_token", pad_width}. Padding performed on left side of the sequence. pad_width will
/// be capped at n-1. left_pad=("_",2) would pad the left side of the sequence with "__" (default={"", 0}}). /// be capped at n-1. left_pad=("_",2) would pad the left side of the sequence with "__" (default={"", 0}}).
/// \param[in] right_pad {"pad_token", pad_width}. Padding performed on right side of the sequence.pad_width will /// \param[in] right_pad {"pad_token", pad_width}. Padding performed on right side of the sequence.pad_width will
/// be capped at n-1. right_pad=("-":2) would pad the right side of the sequence with "--" (default={"", 0}}). /// be capped at n-1. right_pad=("-",2) would pad the right side of the sequence with "--" (default={"", 0}}).
/// \param[in] separator Symbol used to join strings together (default=" "). /// \param[in] separator Symbol used to join strings together (default=" ").
explicit Ngram(const std::vector<int32_t> &ngrams, const std::pair<std::string, int32_t> &left_pad = {"", 0}, explicit Ngram(const std::vector<int32_t> &ngrams, const std::pair<std::string, int32_t> &left_pad = {"", 0},
const std::pair<std::string, int32_t> &right_pad = {"", 0}, const std::string &separator = " ") const std::pair<std::string, int32_t> &right_pad = {"", 0}, const std::string &separator = " ")
@ -276,14 +276,13 @@ class NormalizeUTF8 final : public TensorTransform {
public: public:
/// \brief Constructor. /// \brief Constructor.
/// \param[in] normalize_form Valid values can be any of [NormalizeForm::kNone,NormalizeForm::kNfc, /// \param[in] normalize_form Valid values can be any of [NormalizeForm::kNone,NormalizeForm::kNfc,
/// NormalizeForm::kNfkc, /// NormalizeForm::kNfkc, NormalizeForm::kNfd, NormalizeForm::kNfkd](default=NormalizeForm::kNfkc).
/// NormalizeForm::kNfd, NormalizeForm::kNfkd](default=NormalizeForm::kNfkc).
/// See http://unicode.org/reports/tr15/ for details. /// See http://unicode.org/reports/tr15/ for details.
/// - NormalizeForm.NONE, remain the input string tensor unchanged. /// - NormalizeForm.kNone, remain the input string tensor unchanged.
/// - NormalizeForm.NFC, normalizes with Normalization Form C. /// - NormalizeForm.kNfc, normalizes with Normalization Form C.
/// - NormalizeForm.NFKC, normalizes with Normalization Form KC. /// - NormalizeForm.kNfkc, normalizes with Normalization Form KC.
/// - NormalizeForm.NFD, normalizes with Normalization Form D. /// - NormalizeForm.kNfd, normalizes with Normalization Form D.
/// - NormalizeForm.NFKD, normalizes with Normalization Form KD. /// - NormalizeForm.kNfkd, normalizes with Normalization Form KD.
explicit NormalizeUTF8(NormalizeForm normalize_form = NormalizeForm::kNfkc); explicit NormalizeUTF8(NormalizeForm normalize_form = NormalizeForm::kNfkc);
/// \brief Destructor /// \brief Destructor

View File

@ -79,7 +79,6 @@ def set_seed(seed):
If the seed is set, the generated random number will be fixed, this helps to If the seed is set, the generated random number will be fixed, this helps to
produce deterministic results. produce deterministic results.
Note: Note:
This set_seed function sets the seed in the Python random library and numpy.random library This set_seed function sets the seed in the Python random library and numpy.random library
for deterministic Python augmentations using randomness. This set_seed function should for deterministic Python augmentations using randomness. This set_seed function should
@ -113,6 +112,11 @@ def get_seed():
Returns: Returns:
int, random number seed. int, random number seed.
Examples:
>>> # Get the global configuration of seed.
>>> # If set_seed() is never called before, the default value(std::mt19937::default_seed) will be returned.
>>> seed = ds.config.get_seed()
""" """
return _config.get_seed() return _config.get_seed()
@ -147,6 +151,11 @@ def get_prefetch_size():
Returns: Returns:
int, total number of rows to be prefetched. int, total number of rows to be prefetched.
Examples:
>>> # Get the global configuration of prefetch size.
>>> # If set_prefetch_size() is never called before, the default value(20) will be returned.
>>> prefetch_size = ds.config.get_prefetch_size()
""" """
return _config.get_op_connector_size() return _config.get_op_connector_size()
@ -174,12 +183,17 @@ def set_num_parallel_workers(num):
def get_num_parallel_workers(): def get_num_parallel_workers():
""" """
Get the default number of parallel workers. Get the global configuration of number of parallel workers.
This is the DEFAULT num_parallel_workers value used for each operation, it is not related This is the DEFAULT num_parallel_workers value used for each operation, it is not related
to AutoNumWorker feature. to AutoNumWorker feature.
Returns: Returns:
int, number of parallel workers to be used as a default for each operation. int, number of parallel workers to be used as a default for each operation.
Examples:
>>> # Get the global configuration of parallel workers.
>>> # If set_num_parallel_workers() is never called before, the default value(8) will be returned.
>>> num_parallel_workers = ds.config.get_num_parallel_workers()
""" """
return _config.get_num_parallel_workers() return _config.get_num_parallel_workers()
@ -206,11 +220,15 @@ def set_numa_enable(numa_enable):
def get_numa_enable(): def get_numa_enable():
""" """
Get the default state of numa enabled. Get the state of numa to indicate enabled/disabled.
This is the DEFAULT numa enabled value used for the all process. This is the DEFAULT numa enabled value used for the all process.
Returns: Returns:
bool, the default state of numa enabled. bool, the default state of numa enabled.
Examples:
>>> # Get the global configuration of numa.
>>> numa_state = ds.config.get_numa_enable()
""" """
return _config.get_numa_enable() return _config.get_numa_enable()
@ -236,10 +254,15 @@ def set_monitor_sampling_interval(interval):
def get_monitor_sampling_interval(): def get_monitor_sampling_interval():
""" """
Get the default interval of performance monitor sampling. Get the global configuration of sampling interval of performance monitor.
Returns: Returns:
int, interval (in milliseconds) for performance monitor sampling. int, interval (in milliseconds) for performance monitor sampling.
Examples:
>>> # Get the global configuration of monitor sampling interval.
>>> # If set_monitor_sampling_interval() is never called before, the default value(1000) will be returned.
>>> ds.config.get_monitor_sampling_interval()
""" """
return _config.get_monitor_sampling_interval() return _config.get_monitor_sampling_interval()
@ -299,9 +322,10 @@ def get_auto_num_workers():
Get the setting (turned on or off) automatic number of workers. Get the setting (turned on or off) automatic number of workers.
Returns: Returns:
bool, whether auto num worker feature is turned on. bool, whether auto number worker feature is turned on.
Examples: Examples:
>>> # Get the global configuration of auto number worker feature.
>>> num_workers = ds.config.get_auto_num_workers() >>> num_workers = ds.config.get_auto_num_workers()
""" """
return _config.get_auto_num_workers() return _config.get_auto_num_workers()
@ -334,6 +358,11 @@ def get_callback_timeout():
Returns: Returns:
int, Timeout (in seconds) to be used to end the wait in DSWaitedCallback in case of a deadlock. int, Timeout (in seconds) to be used to end the wait in DSWaitedCallback in case of a deadlock.
Examples:
>>> # Get the global configuration of callback timeout.
>>> # If set_callback_timeout() is never called before, the default value(60) will be returned.
>>> ds.config.get_callback_timeout()
""" """
return _config.get_callback_timeout() return _config.get_callback_timeout()
@ -394,6 +423,10 @@ def get_enable_shared_mem():
Returns: Returns:
bool, the state of shared mem enabled variable (default=True). bool, the state of shared mem enabled variable (default=True).
Examples:
>>> # Get the flag of shared memory feature.
>>> shared_mem_flag = ds.config.get_enable_shared_mem()
""" """
return _config.get_enable_shared_mem() return _config.get_enable_shared_mem()
@ -410,12 +443,14 @@ def set_enable_shared_mem(enable):
TypeError: If enable is not a boolean data type. TypeError: If enable is not a boolean data type.
Examples: Examples:
>>> # Enable shared memory feature to improve the performance of Python multiprocessing.
>>> ds.config.set_enable_shared_mem(True) >>> ds.config.set_enable_shared_mem(True)
""" """
if not isinstance(enable, bool): if not isinstance(enable, bool):
raise TypeError("enable must be of type bool.") raise TypeError("enable must be of type bool.")
_config.set_enable_shared_mem(enable) _config.set_enable_shared_mem(enable)
def set_sending_batches(batch_num): def set_sending_batches(batch_num):
""" """
Set the default sending batches when training with sink_mode=True in Ascend device. Set the default sending batches when training with sink_mode=True in Ascend device.

View File

@ -334,7 +334,7 @@ class Dataset:
Serialize a pipeline into JSON string and dump into file if filename is provided. Serialize a pipeline into JSON string and dump into file if filename is provided.
Args: Args:
filename (str): filename of json file to be saved as filename (str): filename of JSON file to be saved as.
Returns: Returns:
str, JSON string of the pipeline. str, JSON string of the pipeline.
@ -1511,7 +1511,7 @@ class Dataset:
def get_col_names(self): def get_col_names(self):
""" """
Renturn the names of the columns in dataset. Return the names of the columns in dataset.
Returns: Returns:
list, list of column names in the dataset. list, list of column names in the dataset.
@ -1582,7 +1582,7 @@ class Dataset:
def dynamic_min_max_shapes(self): def dynamic_min_max_shapes(self):
""" """
Get minimum and maximum data length of dynamic source data, for graph compilation of ME. Get minimum and maximum data length of dynamic source data, for dynamic graph compilation.
Returns: Returns:
lists, min_shapes, max_shapes of source data. lists, min_shapes, max_shapes of source data.
@ -2187,7 +2187,7 @@ class BatchDataset(Dataset):
self.per_batch_map = _PythonCallable(self.per_batch_map, idx, self.process_pool, arg_q_list, res_q_list) self.per_batch_map = _PythonCallable(self.per_batch_map, idx, self.process_pool, arg_q_list, res_q_list)
self.hook = _ExceptHookHandler() self.hook = _ExceptHookHandler()
atexit.register(_mp_pool_exit_preprocess) atexit.register(_mp_pool_exit_preprocess)
# If python version greater than 3.8, we need to close ThreadPool in atexit for unclean pool teardown. # If Python version greater than 3.8, we need to close ThreadPool in atexit for unclean pool teardown.
if sys.version_info >= (3, 8): if sys.version_info >= (3, 8):
atexit.register(self.process_pool.close) atexit.register(self.process_pool.close)
else: else:
@ -2682,7 +2682,7 @@ class MapDataset(Dataset):
self.operations = iter_specific_operations self.operations = iter_specific_operations
self.hook = _ExceptHookHandler() self.hook = _ExceptHookHandler()
atexit.register(_mp_pool_exit_preprocess) atexit.register(_mp_pool_exit_preprocess)
# If python version greater than 3.8, we need to close ThreadPool in atexit for unclean pool teardown. # If Python version greater than 3.8, we need to close ThreadPool in atexit for unclean pool teardown.
if sys.version_info >= (3, 8): if sys.version_info >= (3, 8):
atexit.register(self.process_pool.close) atexit.register(self.process_pool.close)
@ -3002,7 +3002,7 @@ class TransferDataset(Dataset):
input_dataset (Dataset): Input Dataset to be transferred. input_dataset (Dataset): Input Dataset to be transferred.
send_epoch_end (bool, optional): Whether to send end of sequence to device or not (default=True). send_epoch_end (bool, optional): Whether to send end of sequence to device or not (default=True).
create_data_info_queue (bool, optional): Whether to create queue which stores create_data_info_queue (bool, optional): Whether to create queue which stores
types and shapes of data or not(default=False). types and shapes of data or not (default=False).
Raises: Raises:
TypeError: If device_type is empty. TypeError: If device_type is empty.
@ -4798,12 +4798,12 @@ class VOCDataset(MappableDataset):
title = {The Pascal Visual Object Classes (VOC) Challenge}, title = {The Pascal Visual Object Classes (VOC) Challenge},
journal = {International Journal of Computer Vision}, journal = {International Journal of Computer Vision},
volume = {88}, volume = {88},
year = {2010}, year = {2012},
number = {2}, number = {2},
month = {jun}, month = {jun},
pages = {303--338}, pages = {303--338},
biburl = {http://host.robots.ox.ac.uk/pascal/VOC/pubs/everingham10.html#bibtex}, biburl = {http://host.robots.ox.ac.uk/pascal/VOC/pubs/everingham10.html#bibtex},
howpublished = {http://host.robots.ox.ac.uk/pascal/VOC/voc{year}/index.html} howpublished = {http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html}
} }
""" """
@ -4959,10 +4959,11 @@ class CocoDataset(MappableDataset):
About COCO dataset: About COCO dataset:
COCO is a large-scale object detection, segmentation, and captioning dataset. COCO(Microsoft Common Objects in Context) is a large-scale object detection, segmentation, and captioning dataset
It contains 91 common object categories with 82 of them having more than 5,000 with several features: Object segmentation, Recognition in context, Superpixel stuff segmentation,
labeled instances. In contrast to the popular ImageNet dataset, COCO has fewer 330K images (>200K labeled), 1.5 million object instances, 80 object categories, 91 stuff categories,
categories but more instances per category. 5 captions per image, 250,000 people with keypoints. In contrast to the popular ImageNet dataset, COCO has fewer
categories but more instances in per category.
You can unzip the original COCO-2017 dataset files into this directory structure and read by MindSpore's API. You can unzip the original COCO-2017 dataset files into this directory structure and read by MindSpore's API.
@ -5304,7 +5305,7 @@ class CLUEDataset(SourceDataset):
About CLUE dataset: About CLUE dataset:
CLUE, a Chinese Language Understanding Evaluation benchmark. It contains eight different CLUE, a Chinese Language Understanding Evaluation benchmark. It contains multiple
tasks, including single-sentence classification, sentence pair classification, and machine tasks, including single-sentence classification, sentence pair classification, and machine
reading comprehension. reading comprehension.

View File

@ -14,7 +14,7 @@
# ============================================================================== # ==============================================================================
""" """
This dataset module creates an internal queue class to more optimally pass data This dataset module creates an internal queue class to more optimally pass data
between multiple processes in python. It has same API as multiprocessing.queue between multiple processes in Python. It has same API as multiprocessing.queue
but it will pass large data through shared memory. but it will pass large data through shared memory.
""" """

View File

@ -121,29 +121,29 @@ class BuiltinSampler:
self.child_sampler = sampler self.child_sampler = sampler
def get_child(self): def get_child(self):
""" add a child sampler. """ """ Get the child sampler. """
return self.child_sampler return self.child_sampler
def parse_child(self): def parse_child(self):
"""Parse the child sampler.""" """ Parse the child sampler. """
c_child_sampler = None c_child_sampler = None
if self.child_sampler is not None: if self.child_sampler is not None:
c_child_sampler = self.child_sampler.parse() c_child_sampler = self.child_sampler.parse()
return c_child_sampler return c_child_sampler
def parse_child_for_minddataset(self): def parse_child_for_minddataset(self):
"""Parse the child sampler for MindRecord.""" """ Parse the child sampler for MindRecord. """
c_child_sampler = None c_child_sampler = None
if self.child_sampler is not None: if self.child_sampler is not None:
c_child_sampler = self.child_sampler.parse_for_minddataset() c_child_sampler = self.child_sampler.parse_for_minddataset()
return c_child_sampler return c_child_sampler
def is_shuffled(self): def is_shuffled(self):
""" not implemented """ """ Not implemented. """
raise NotImplementedError("Sampler must implement is_shuffled.") raise NotImplementedError("Sampler must implement is_shuffled.")
def is_sharded(self): def is_sharded(self):
""" not implemented """ """ Not implemented. """
raise NotImplementedError("Sampler must implement is_sharded.") raise NotImplementedError("Sampler must implement is_sharded.")
def get_num_samples(self): def get_num_samples(self):
@ -313,8 +313,10 @@ class DistributedSampler(BuiltinSampler):
shard_id (int): Shard ID of the current shard, which should within the range of [0, num_shards-1]. shard_id (int): Shard ID of the current shard, which should within the range of [0, num_shards-1].
shuffle (bool, optional): If True, the indices are shuffled, otherwise it will not be shuffled(default=True). shuffle (bool, optional): If True, the indices are shuffled, otherwise it will not be shuffled(default=True).
num_samples (int, optional): The number of samples to draw (default=None, which means sample all elements). num_samples (int, optional): The number of samples to draw (default=None, which means sample all elements).
offset(int, optional): The starting shard ID where the elements in the dataset are sent to (default=-1), which offset(int, optional): The starting shard ID where the elements in the dataset are sent to, which
should be no more than num_shards. should be no more than num_shards. This parameter is only valid when a ConcatDataset takes
a DistributedSampler as its sampler. It will affect the number of samples of per shard
(default=-1, which means each shard has same number of samples).
Examples: Examples:
>>> # creates a distributed sampler with 10 shards in total. This shard is shard 5. >>> # creates a distributed sampler with 10 shards in total. This shard is shard 5.
@ -329,9 +331,9 @@ class DistributedSampler(BuiltinSampler):
TypeError: If shuffle is not a boolean value. TypeError: If shuffle is not a boolean value.
TypeError: If num_samples is not an integer value. TypeError: If num_samples is not an integer value.
TypeError: If offset is not an integer value. TypeError: If offset is not an integer value.
ValueError: If num_samples is a negative value.
RuntimeError: If num_shards is not a positive value. RuntimeError: If num_shards is not a positive value.
RuntimeError: If shard_id is smaller than 0 or equal to num_shards or larger than num_shards. RuntimeError: If shard_id is smaller than 0 or equal to num_shards or larger than num_shards.
RuntimeError: If num_samples is a negative value.
RuntimeError: If offset is greater than num_shards. RuntimeError: If offset is greater than num_shards.
""" """
@ -411,7 +413,7 @@ class PKSampler(BuiltinSampler):
num_class (int, optional): Number of classes to sample (default=None, sample all classes). num_class (int, optional): Number of classes to sample (default=None, sample all classes).
The parameter does not supported to specify currently. The parameter does not supported to specify currently.
shuffle (bool, optional): If True, the class IDs are shuffled, otherwise it will not be shuffle (bool, optional): If True, the class IDs are shuffled, otherwise it will not be
shuffled(default=False). shuffled (default=False).
class_column (str, optional): Name of column with class labels for MindDataset (default='label'). class_column (str, optional): Name of column with class labels for MindDataset (default='label').
num_samples (int, optional): The number of samples to draw (default=None, which means sample all elements). num_samples (int, optional): The number of samples to draw (default=None, which means sample all elements).
@ -423,13 +425,12 @@ class PKSampler(BuiltinSampler):
... sampler=sampler) ... sampler=sampler)
Raises: Raises:
TypeError: If num_val is not a positive value.
TypeError: If shuffle is not a boolean value. TypeError: If shuffle is not a boolean value.
TypeError: If class_column is not a str value. TypeError: If class_column is not a str value.
TypeError: If num_samples is not an integer value. TypeError: If num_samples is not an integer value.
NotImplementedError: If num_class is not None. NotImplementedError: If num_class is not None.
RuntimeError: If num_val is not a positive value. RuntimeError: If num_val is not a positive value.
RuntimeError: If num_samples is a negative value. ValueError: If num_samples is a negative value.
""" """
def __init__(self, num_val, num_class=None, shuffle=False, class_column='label', num_samples=None): def __init__(self, num_val, num_class=None, shuffle=False, class_column='label', num_samples=None):
@ -508,7 +509,7 @@ class RandomSampler(BuiltinSampler):
Raises: Raises:
TypeError: If replacement is not a boolean value. TypeError: If replacement is not a boolean value.
TypeError: If num_samples is not an integer value. TypeError: If num_samples is not an integer value.
RuntimeError: If num_samples is a negative value. ValueError: If num_samples is a negative value.
""" """
def __init__(self, replacement=False, num_samples=None): def __init__(self, replacement=False, num_samples=None):
@ -573,7 +574,7 @@ class SequentialSampler(BuiltinSampler):
TypeError: If start_index is not an integer value. TypeError: If start_index is not an integer value.
TypeError: If num_samples is not an integer value. TypeError: If num_samples is not an integer value.
RuntimeError: If start_index is a negative value. RuntimeError: If start_index is a negative value.
RuntimeError: If num_samples is a negative value. ValueError: If num_samples is a negative value.
""" """
def __init__(self, start_index=None, num_samples=None): def __init__(self, start_index=None, num_samples=None):
@ -641,7 +642,7 @@ class SubsetSampler(BuiltinSampler):
Raises: Raises:
TypeError: If type of indices element is not a number. TypeError: If type of indices element is not a number.
TypeError: If num_samples is not an integer value. TypeError: If num_samples is not an integer value.
RuntimeError: If num_samples is a negative value. ValueError: If num_samples is a negative value.
""" """
def __init__(self, indices, num_samples=None): def __init__(self, indices, num_samples=None):
@ -713,7 +714,7 @@ class SubsetRandomSampler(SubsetSampler):
Samples the elements randomly from a sequence of indices. Samples the elements randomly from a sequence of indices.
Args: Args:
indices (Any iterable python object but string): A sequence of indices. indices (Any iterable Python object but string): A sequence of indices.
num_samples (int, optional): Number of elements to sample (default=None, which means sample all elements). num_samples (int, optional): Number of elements to sample (default=None, which means sample all elements).
Examples: Examples:
@ -726,7 +727,7 @@ class SubsetRandomSampler(SubsetSampler):
Raises: Raises:
TypeError: If type of indices element is not a number. TypeError: If type of indices element is not a number.
TypeError: If num_samples is not an integer value. TypeError: If num_samples is not an integer value.
RuntimeError: If num_samples is a negative value. ValueError: If num_samples is a negative value.
""" """
def parse(self): def parse(self):
@ -806,7 +807,7 @@ class WeightedRandomSampler(BuiltinSampler):
TypeError: If num_samples is not an integer value. TypeError: If num_samples is not an integer value.
TypeError: If replacement is not a boolean value. TypeError: If replacement is not a boolean value.
RuntimeError: If weights is empty or all zero. RuntimeError: If weights is empty or all zero.
RuntimeError: If num_samples is a negative value. ValueError: If num_samples is a negative value.
""" """
def __init__(self, weights, num_samples=None, replacement=True): def __init__(self, weights, num_samples=None, replacement=True):

View File

@ -27,15 +27,15 @@ from ..vision.utils import Inter, Border, ImageBatchFormat
def serialize(dataset, json_filepath=""): def serialize(dataset, json_filepath=""):
""" """
Serialize dataset pipeline into a json file. Serialize dataset pipeline into a JSON file.
Note: Note:
Currently some python objects are not supported to be serialized. Currently some Python objects are not supported to be serialized.
For python function serialization of map operator, de.serialize will only return its function name. For Python function serialization of map operator, de.serialize will only return its function name.
Args: Args:
dataset (Dataset): The starting node. dataset (Dataset): The starting node.
json_filepath (str): The filepath where a serialized json file will be generated. json_filepath (str): The filepath where a serialized JSON file will be generated.
Returns: Returns:
Dict, The dictionary contains the serialized dataset graph. Dict, The dictionary contains the serialized dataset graph.
@ -48,7 +48,7 @@ def serialize(dataset, json_filepath=""):
>>> one_hot_encode = c_transforms.OneHot(10) # num_classes is input argument >>> one_hot_encode = c_transforms.OneHot(10) # num_classes is input argument
>>> dataset = dataset.map(operation=one_hot_encode, input_column_names="label") >>> dataset = dataset.map(operation=one_hot_encode, input_column_names="label")
>>> dataset = dataset.batch(batch_size=10, drop_remainder=True) >>> dataset = dataset.batch(batch_size=10, drop_remainder=True)
>>> # serialize it to json file >>> # serialize it to JSON file
>>> ds.engine.serialize(dataset, json_filepath="/path/to/mnist_dataset_pipeline.json") >>> ds.engine.serialize(dataset, json_filepath="/path/to/mnist_dataset_pipeline.json")
>>> serialized_data = ds.engine.serialize(dataset) # serialize it to Python dict >>> serialized_data = ds.engine.serialize(dataset) # serialize it to Python dict
""" """
@ -57,27 +57,27 @@ def serialize(dataset, json_filepath=""):
def deserialize(input_dict=None, json_filepath=None): def deserialize(input_dict=None, json_filepath=None):
""" """
Construct a de pipeline from a json file produced by de.serialize(). Construct a de pipeline from a JSON file produced by de.serialize().
Note: Note:
Currently python function deserialization of map operator are not supported. Currently Python function deserialization of map operator are not supported.
Args: Args:
input_dict (dict): A Python dictionary containing a serialized dataset graph. input_dict (dict): A Python dictionary containing a serialized dataset graph.
json_filepath (str): A path to the json file. json_filepath (str): A path to the JSON file.
Returns: Returns:
de.Dataset or None if error occurs. de.Dataset or None if error occurs.
Raises: Raises:
OSError: Can not open the json file. OSError: Can not open the JSON file.
Examples: Examples:
>>> dataset = ds.MnistDataset(mnist_dataset_dir, 100) >>> dataset = ds.MnistDataset(mnist_dataset_dir, 100)
>>> one_hot_encode = c_transforms.OneHot(10) # num_classes is input argument >>> one_hot_encode = c_transforms.OneHot(10) # num_classes is input argument
>>> dataset = dataset.map(operation=one_hot_encode, input_column_names="label") >>> dataset = dataset.map(operation=one_hot_encode, input_column_names="label")
>>> dataset = dataset.batch(batch_size=10, drop_remainder=True) >>> dataset = dataset.batch(batch_size=10, drop_remainder=True)
>>> # Use case 1: to/from json file >>> # Use case 1: to/from JSON file
>>> ds.engine.serialize(dataset, json_filepath="/path/to/mnist_dataset_pipeline.json") >>> ds.engine.serialize(dataset, json_filepath="/path/to/mnist_dataset_pipeline.json")
>>> dataset = ds.engine.deserialize(json_filepath="/path/to/mnist_dataset_pipeline.json") >>> dataset = ds.engine.deserialize(json_filepath="/path/to/mnist_dataset_pipeline.json")
>>> # Use case 2: to/from Python dictionary >>> # Use case 2: to/from Python dictionary
@ -113,8 +113,15 @@ def show(dataset, indentation=2):
Args: Args:
dataset (Dataset): The starting node. dataset (Dataset): The starting node.
indentation (int, optional): The indentation used by the json print. indentation (int, optional): The indentation used by the JSON print.
Do not indent if indentation is None. Do not indent if indentation is None.
Examples:
>>> dataset = ds.MnistDataset(mnist_dataset_dir, 100)
>>> one_hot_encode = c_transforms.OneHot(10)
>>> dataset = dataset.map(operation=one_hot_encode, input_column_names="label")
>>> dataset = dataset.batch(batch_size=10, drop_remainder=True)
>>> ds.show(dataset)
""" """
pipeline = dataset.to_json() pipeline = dataset.to_json()
@ -128,13 +135,21 @@ def compare(pipeline1, pipeline2):
Args: Args:
pipeline1 (Dataset): a dataset pipeline. pipeline1 (Dataset): a dataset pipeline.
pipeline2 (Dataset): a dataset pipeline. pipeline2 (Dataset): a dataset pipeline.
Returns:
Whether pipeline1 is equal to pipeline2.
Examples:
>>> pipeline1 = ds.MnistDataset(mnist_dataset_dir, 100)
>>> pipeline2 = ds.Cifar10Dataset(cifar_dataset_dir, 100)
>>> ds.compare(pipeline1, pipeline2)
""" """
return pipeline1.to_json() == pipeline2.to_json() return pipeline1.to_json() == pipeline2.to_json()
def construct_pipeline(node): def construct_pipeline(node):
"""Construct the Python Dataset objects by following the dictionary deserialized from json file.""" """Construct the Python Dataset objects by following the dictionary deserialized from JSON file."""
op_type = node.get('op_type') op_type = node.get('op_type')
if not op_type: if not op_type:
raise ValueError("op_type field in the json file can't be None.") raise ValueError("op_type field in the json file can't be None.")