forked from mindspore-Ecosystem/mindspore
!19765 Fix docs issues
Merge pull request !19765 from luoyang/code_docs_python-doc
This commit is contained in:
commit
f86644358e
|
@ -117,7 +117,6 @@ class BertTokenizer final : public TensorTransform {
|
|||
};
|
||||
|
||||
/// \brief Apply case fold operation on UTF-8 string tensors.
|
||||
/// \return Shared pointer to the current TensorOperation.
|
||||
class CaseFold final : public TensorTransform {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
|
@ -142,7 +141,8 @@ class JiebaTokenizer final : public TensorTransform {
|
|||
/// official website of cppjieba (https://github.com/yanyiwu/cppjieba).
|
||||
/// \param[in] mp_path Dictionary file is used by the MPSegment algorithm. The dictionary can be obtained on the
|
||||
/// official website of cppjieba (https://github.com/yanyiwu/cppjieba).
|
||||
/// \param[in] mode Valid values can be any of JiebaMode.MP, JiebaMode.HMM and JiebaMode.MIX (default=JiebaMode.MIX).
|
||||
/// \param[in] mode Valid values can be any of JiebaMode.kMP, JiebaMode.kHMM and JiebaMode.kMIX
|
||||
/// (default=JiebaMode.kMIX).
|
||||
/// - JiebaMode.kMP, tokenizes with MPSegment algorithm.
|
||||
/// - JiebaMode.kHMM, tokenizes with Hidden Markov Model Segment algorithm.
|
||||
/// - JiebaMode.kMIX, tokenizes with a mix of MPSegment and HMMSegment algorithms.
|
||||
|
@ -248,7 +248,7 @@ class Ngram final : public TensorTransform {
|
|||
/// \param[in] left_pad {"pad_token", pad_width}. Padding performed on left side of the sequence. pad_width will
|
||||
/// be capped at n-1. left_pad=("_",2) would pad the left side of the sequence with "__" (default={"", 0}}).
|
||||
/// \param[in] right_pad {"pad_token", pad_width}. Padding performed on right side of the sequence.pad_width will
|
||||
/// be capped at n-1. right_pad=("-":2) would pad the right side of the sequence with "--" (default={"", 0}}).
|
||||
/// be capped at n-1. right_pad=("-",2) would pad the right side of the sequence with "--" (default={"", 0}}).
|
||||
/// \param[in] separator Symbol used to join strings together (default=" ").
|
||||
explicit Ngram(const std::vector<int32_t> &ngrams, const std::pair<std::string, int32_t> &left_pad = {"", 0},
|
||||
const std::pair<std::string, int32_t> &right_pad = {"", 0}, const std::string &separator = " ")
|
||||
|
@ -276,14 +276,13 @@ class NormalizeUTF8 final : public TensorTransform {
|
|||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] normalize_form Valid values can be any of [NormalizeForm::kNone,NormalizeForm::kNfc,
|
||||
/// NormalizeForm::kNfkc,
|
||||
/// NormalizeForm::kNfd, NormalizeForm::kNfkd](default=NormalizeForm::kNfkc).
|
||||
/// NormalizeForm::kNfkc, NormalizeForm::kNfd, NormalizeForm::kNfkd](default=NormalizeForm::kNfkc).
|
||||
/// See http://unicode.org/reports/tr15/ for details.
|
||||
/// - NormalizeForm.NONE, remain the input string tensor unchanged.
|
||||
/// - NormalizeForm.NFC, normalizes with Normalization Form C.
|
||||
/// - NormalizeForm.NFKC, normalizes with Normalization Form KC.
|
||||
/// - NormalizeForm.NFD, normalizes with Normalization Form D.
|
||||
/// - NormalizeForm.NFKD, normalizes with Normalization Form KD.
|
||||
/// - NormalizeForm.kNone, remain the input string tensor unchanged.
|
||||
/// - NormalizeForm.kNfc, normalizes with Normalization Form C.
|
||||
/// - NormalizeForm.kNfkc, normalizes with Normalization Form KC.
|
||||
/// - NormalizeForm.kNfd, normalizes with Normalization Form D.
|
||||
/// - NormalizeForm.kNfkd, normalizes with Normalization Form KD.
|
||||
explicit NormalizeUTF8(NormalizeForm normalize_form = NormalizeForm::kNfkc);
|
||||
|
||||
/// \brief Destructor
|
||||
|
|
|
@ -79,7 +79,6 @@ def set_seed(seed):
|
|||
If the seed is set, the generated random number will be fixed, this helps to
|
||||
produce deterministic results.
|
||||
|
||||
|
||||
Note:
|
||||
This set_seed function sets the seed in the Python random library and numpy.random library
|
||||
for deterministic Python augmentations using randomness. This set_seed function should
|
||||
|
@ -113,6 +112,11 @@ def get_seed():
|
|||
|
||||
Returns:
|
||||
int, random number seed.
|
||||
|
||||
Examples:
|
||||
>>> # Get the global configuration of seed.
|
||||
>>> # If set_seed() is never called before, the default value(std::mt19937::default_seed) will be returned.
|
||||
>>> seed = ds.config.get_seed()
|
||||
"""
|
||||
return _config.get_seed()
|
||||
|
||||
|
@ -147,6 +151,11 @@ def get_prefetch_size():
|
|||
|
||||
Returns:
|
||||
int, total number of rows to be prefetched.
|
||||
|
||||
Examples:
|
||||
>>> # Get the global configuration of prefetch size.
|
||||
>>> # If set_prefetch_size() is never called before, the default value(20) will be returned.
|
||||
>>> prefetch_size = ds.config.get_prefetch_size()
|
||||
"""
|
||||
return _config.get_op_connector_size()
|
||||
|
||||
|
@ -174,12 +183,17 @@ def set_num_parallel_workers(num):
|
|||
|
||||
def get_num_parallel_workers():
|
||||
"""
|
||||
Get the default number of parallel workers.
|
||||
Get the global configuration of number of parallel workers.
|
||||
This is the DEFAULT num_parallel_workers value used for each operation, it is not related
|
||||
to AutoNumWorker feature.
|
||||
|
||||
Returns:
|
||||
int, number of parallel workers to be used as a default for each operation.
|
||||
|
||||
Examples:
|
||||
>>> # Get the global configuration of parallel workers.
|
||||
>>> # If set_num_parallel_workers() is never called before, the default value(8) will be returned.
|
||||
>>> num_parallel_workers = ds.config.get_num_parallel_workers()
|
||||
"""
|
||||
return _config.get_num_parallel_workers()
|
||||
|
||||
|
@ -206,11 +220,15 @@ def set_numa_enable(numa_enable):
|
|||
|
||||
def get_numa_enable():
|
||||
"""
|
||||
Get the default state of numa enabled.
|
||||
Get the state of numa to indicate enabled/disabled.
|
||||
This is the DEFAULT numa enabled value used for the all process.
|
||||
|
||||
Returns:
|
||||
bool, the default state of numa enabled.
|
||||
|
||||
Examples:
|
||||
>>> # Get the global configuration of numa.
|
||||
>>> numa_state = ds.config.get_numa_enable()
|
||||
"""
|
||||
return _config.get_numa_enable()
|
||||
|
||||
|
@ -236,10 +254,15 @@ def set_monitor_sampling_interval(interval):
|
|||
|
||||
def get_monitor_sampling_interval():
|
||||
"""
|
||||
Get the default interval of performance monitor sampling.
|
||||
Get the global configuration of sampling interval of performance monitor.
|
||||
|
||||
Returns:
|
||||
int, interval (in milliseconds) for performance monitor sampling.
|
||||
|
||||
Examples:
|
||||
>>> # Get the global configuration of monitor sampling interval.
|
||||
>>> # If set_monitor_sampling_interval() is never called before, the default value(1000) will be returned.
|
||||
>>> ds.config.get_monitor_sampling_interval()
|
||||
"""
|
||||
return _config.get_monitor_sampling_interval()
|
||||
|
||||
|
@ -299,9 +322,10 @@ def get_auto_num_workers():
|
|||
Get the setting (turned on or off) automatic number of workers.
|
||||
|
||||
Returns:
|
||||
bool, whether auto num worker feature is turned on.
|
||||
bool, whether auto number worker feature is turned on.
|
||||
|
||||
Examples:
|
||||
>>> # Get the global configuration of auto number worker feature.
|
||||
>>> num_workers = ds.config.get_auto_num_workers()
|
||||
"""
|
||||
return _config.get_auto_num_workers()
|
||||
|
@ -334,6 +358,11 @@ def get_callback_timeout():
|
|||
|
||||
Returns:
|
||||
int, Timeout (in seconds) to be used to end the wait in DSWaitedCallback in case of a deadlock.
|
||||
|
||||
Examples:
|
||||
>>> # Get the global configuration of callback timeout.
|
||||
>>> # If set_callback_timeout() is never called before, the default value(60) will be returned.
|
||||
>>> ds.config.get_callback_timeout()
|
||||
"""
|
||||
return _config.get_callback_timeout()
|
||||
|
||||
|
@ -394,6 +423,10 @@ def get_enable_shared_mem():
|
|||
|
||||
Returns:
|
||||
bool, the state of shared mem enabled variable (default=True).
|
||||
|
||||
Examples:
|
||||
>>> # Get the flag of shared memory feature.
|
||||
>>> shared_mem_flag = ds.config.get_enable_shared_mem()
|
||||
"""
|
||||
return _config.get_enable_shared_mem()
|
||||
|
||||
|
@ -410,12 +443,14 @@ def set_enable_shared_mem(enable):
|
|||
TypeError: If enable is not a boolean data type.
|
||||
|
||||
Examples:
|
||||
>>> # Enable shared memory feature to improve the performance of Python multiprocessing.
|
||||
>>> ds.config.set_enable_shared_mem(True)
|
||||
"""
|
||||
if not isinstance(enable, bool):
|
||||
raise TypeError("enable must be of type bool.")
|
||||
_config.set_enable_shared_mem(enable)
|
||||
|
||||
|
||||
def set_sending_batches(batch_num):
|
||||
"""
|
||||
Set the default sending batches when training with sink_mode=True in Ascend device.
|
||||
|
|
|
@ -334,7 +334,7 @@ class Dataset:
|
|||
Serialize a pipeline into JSON string and dump into file if filename is provided.
|
||||
|
||||
Args:
|
||||
filename (str): filename of json file to be saved as
|
||||
filename (str): filename of JSON file to be saved as.
|
||||
|
||||
Returns:
|
||||
str, JSON string of the pipeline.
|
||||
|
@ -1511,7 +1511,7 @@ class Dataset:
|
|||
|
||||
def get_col_names(self):
|
||||
"""
|
||||
Renturn the names of the columns in dataset.
|
||||
Return the names of the columns in dataset.
|
||||
|
||||
Returns:
|
||||
list, list of column names in the dataset.
|
||||
|
@ -1582,7 +1582,7 @@ class Dataset:
|
|||
|
||||
def dynamic_min_max_shapes(self):
|
||||
"""
|
||||
Get minimum and maximum data length of dynamic source data, for graph compilation of ME.
|
||||
Get minimum and maximum data length of dynamic source data, for dynamic graph compilation.
|
||||
|
||||
Returns:
|
||||
lists, min_shapes, max_shapes of source data.
|
||||
|
@ -2187,7 +2187,7 @@ class BatchDataset(Dataset):
|
|||
self.per_batch_map = _PythonCallable(self.per_batch_map, idx, self.process_pool, arg_q_list, res_q_list)
|
||||
self.hook = _ExceptHookHandler()
|
||||
atexit.register(_mp_pool_exit_preprocess)
|
||||
# If python version greater than 3.8, we need to close ThreadPool in atexit for unclean pool teardown.
|
||||
# If Python version greater than 3.8, we need to close ThreadPool in atexit for unclean pool teardown.
|
||||
if sys.version_info >= (3, 8):
|
||||
atexit.register(self.process_pool.close)
|
||||
else:
|
||||
|
@ -2682,7 +2682,7 @@ class MapDataset(Dataset):
|
|||
self.operations = iter_specific_operations
|
||||
self.hook = _ExceptHookHandler()
|
||||
atexit.register(_mp_pool_exit_preprocess)
|
||||
# If python version greater than 3.8, we need to close ThreadPool in atexit for unclean pool teardown.
|
||||
# If Python version greater than 3.8, we need to close ThreadPool in atexit for unclean pool teardown.
|
||||
if sys.version_info >= (3, 8):
|
||||
atexit.register(self.process_pool.close)
|
||||
|
||||
|
@ -4798,12 +4798,12 @@ class VOCDataset(MappableDataset):
|
|||
title = {The Pascal Visual Object Classes (VOC) Challenge},
|
||||
journal = {International Journal of Computer Vision},
|
||||
volume = {88},
|
||||
year = {2010},
|
||||
year = {2012},
|
||||
number = {2},
|
||||
month = {jun},
|
||||
pages = {303--338},
|
||||
biburl = {http://host.robots.ox.ac.uk/pascal/VOC/pubs/everingham10.html#bibtex},
|
||||
howpublished = {http://host.robots.ox.ac.uk/pascal/VOC/voc{year}/index.html}
|
||||
howpublished = {http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html}
|
||||
}
|
||||
"""
|
||||
|
||||
|
@ -4959,10 +4959,11 @@ class CocoDataset(MappableDataset):
|
|||
|
||||
About COCO dataset:
|
||||
|
||||
COCO is a large-scale object detection, segmentation, and captioning dataset.
|
||||
It contains 91 common object categories with 82 of them having more than 5,000
|
||||
labeled instances. In contrast to the popular ImageNet dataset, COCO has fewer
|
||||
categories but more instances per category.
|
||||
COCO(Microsoft Common Objects in Context) is a large-scale object detection, segmentation, and captioning dataset
|
||||
with several features: Object segmentation, Recognition in context, Superpixel stuff segmentation,
|
||||
330K images (>200K labeled), 1.5 million object instances, 80 object categories, 91 stuff categories,
|
||||
5 captions per image, 250,000 people with keypoints. In contrast to the popular ImageNet dataset, COCO has fewer
|
||||
categories but more instances in per category.
|
||||
|
||||
You can unzip the original COCO-2017 dataset files into this directory structure and read by MindSpore's API.
|
||||
|
||||
|
@ -5304,7 +5305,7 @@ class CLUEDataset(SourceDataset):
|
|||
|
||||
About CLUE dataset:
|
||||
|
||||
CLUE, a Chinese Language Understanding Evaluation benchmark. It contains eight different
|
||||
CLUE, a Chinese Language Understanding Evaluation benchmark. It contains multiple
|
||||
tasks, including single-sentence classification, sentence pair classification, and machine
|
||||
reading comprehension.
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# ==============================================================================
|
||||
"""
|
||||
This dataset module creates an internal queue class to more optimally pass data
|
||||
between multiple processes in python. It has same API as multiprocessing.queue
|
||||
between multiple processes in Python. It has same API as multiprocessing.queue
|
||||
but it will pass large data through shared memory.
|
||||
"""
|
||||
|
||||
|
|
|
@ -121,7 +121,7 @@ class BuiltinSampler:
|
|||
self.child_sampler = sampler
|
||||
|
||||
def get_child(self):
|
||||
""" add a child sampler. """
|
||||
""" Get the child sampler. """
|
||||
return self.child_sampler
|
||||
|
||||
def parse_child(self):
|
||||
|
@ -139,11 +139,11 @@ class BuiltinSampler:
|
|||
return c_child_sampler
|
||||
|
||||
def is_shuffled(self):
|
||||
""" not implemented """
|
||||
""" Not implemented. """
|
||||
raise NotImplementedError("Sampler must implement is_shuffled.")
|
||||
|
||||
def is_sharded(self):
|
||||
""" not implemented """
|
||||
""" Not implemented. """
|
||||
raise NotImplementedError("Sampler must implement is_sharded.")
|
||||
|
||||
def get_num_samples(self):
|
||||
|
@ -313,8 +313,10 @@ class DistributedSampler(BuiltinSampler):
|
|||
shard_id (int): Shard ID of the current shard, which should within the range of [0, num_shards-1].
|
||||
shuffle (bool, optional): If True, the indices are shuffled, otherwise it will not be shuffled(default=True).
|
||||
num_samples (int, optional): The number of samples to draw (default=None, which means sample all elements).
|
||||
offset(int, optional): The starting shard ID where the elements in the dataset are sent to (default=-1), which
|
||||
should be no more than num_shards.
|
||||
offset(int, optional): The starting shard ID where the elements in the dataset are sent to, which
|
||||
should be no more than num_shards. This parameter is only valid when a ConcatDataset takes
|
||||
a DistributedSampler as its sampler. It will affect the number of samples of per shard
|
||||
(default=-1, which means each shard has same number of samples).
|
||||
|
||||
Examples:
|
||||
>>> # creates a distributed sampler with 10 shards in total. This shard is shard 5.
|
||||
|
@ -329,9 +331,9 @@ class DistributedSampler(BuiltinSampler):
|
|||
TypeError: If shuffle is not a boolean value.
|
||||
TypeError: If num_samples is not an integer value.
|
||||
TypeError: If offset is not an integer value.
|
||||
ValueError: If num_samples is a negative value.
|
||||
RuntimeError: If num_shards is not a positive value.
|
||||
RuntimeError: If shard_id is smaller than 0 or equal to num_shards or larger than num_shards.
|
||||
RuntimeError: If num_samples is a negative value.
|
||||
RuntimeError: If offset is greater than num_shards.
|
||||
"""
|
||||
|
||||
|
@ -423,13 +425,12 @@ class PKSampler(BuiltinSampler):
|
|||
... sampler=sampler)
|
||||
|
||||
Raises:
|
||||
TypeError: If num_val is not a positive value.
|
||||
TypeError: If shuffle is not a boolean value.
|
||||
TypeError: If class_column is not a str value.
|
||||
TypeError: If num_samples is not an integer value.
|
||||
NotImplementedError: If num_class is not None.
|
||||
RuntimeError: If num_val is not a positive value.
|
||||
RuntimeError: If num_samples is a negative value.
|
||||
ValueError: If num_samples is a negative value.
|
||||
"""
|
||||
|
||||
def __init__(self, num_val, num_class=None, shuffle=False, class_column='label', num_samples=None):
|
||||
|
@ -508,7 +509,7 @@ class RandomSampler(BuiltinSampler):
|
|||
Raises:
|
||||
TypeError: If replacement is not a boolean value.
|
||||
TypeError: If num_samples is not an integer value.
|
||||
RuntimeError: If num_samples is a negative value.
|
||||
ValueError: If num_samples is a negative value.
|
||||
"""
|
||||
|
||||
def __init__(self, replacement=False, num_samples=None):
|
||||
|
@ -573,7 +574,7 @@ class SequentialSampler(BuiltinSampler):
|
|||
TypeError: If start_index is not an integer value.
|
||||
TypeError: If num_samples is not an integer value.
|
||||
RuntimeError: If start_index is a negative value.
|
||||
RuntimeError: If num_samples is a negative value.
|
||||
ValueError: If num_samples is a negative value.
|
||||
"""
|
||||
|
||||
def __init__(self, start_index=None, num_samples=None):
|
||||
|
@ -641,7 +642,7 @@ class SubsetSampler(BuiltinSampler):
|
|||
Raises:
|
||||
TypeError: If type of indices element is not a number.
|
||||
TypeError: If num_samples is not an integer value.
|
||||
RuntimeError: If num_samples is a negative value.
|
||||
ValueError: If num_samples is a negative value.
|
||||
"""
|
||||
|
||||
def __init__(self, indices, num_samples=None):
|
||||
|
@ -713,7 +714,7 @@ class SubsetRandomSampler(SubsetSampler):
|
|||
Samples the elements randomly from a sequence of indices.
|
||||
|
||||
Args:
|
||||
indices (Any iterable python object but string): A sequence of indices.
|
||||
indices (Any iterable Python object but string): A sequence of indices.
|
||||
num_samples (int, optional): Number of elements to sample (default=None, which means sample all elements).
|
||||
|
||||
Examples:
|
||||
|
@ -726,7 +727,7 @@ class SubsetRandomSampler(SubsetSampler):
|
|||
Raises:
|
||||
TypeError: If type of indices element is not a number.
|
||||
TypeError: If num_samples is not an integer value.
|
||||
RuntimeError: If num_samples is a negative value.
|
||||
ValueError: If num_samples is a negative value.
|
||||
"""
|
||||
|
||||
def parse(self):
|
||||
|
@ -806,7 +807,7 @@ class WeightedRandomSampler(BuiltinSampler):
|
|||
TypeError: If num_samples is not an integer value.
|
||||
TypeError: If replacement is not a boolean value.
|
||||
RuntimeError: If weights is empty or all zero.
|
||||
RuntimeError: If num_samples is a negative value.
|
||||
ValueError: If num_samples is a negative value.
|
||||
"""
|
||||
|
||||
def __init__(self, weights, num_samples=None, replacement=True):
|
||||
|
|
|
@ -27,15 +27,15 @@ from ..vision.utils import Inter, Border, ImageBatchFormat
|
|||
|
||||
def serialize(dataset, json_filepath=""):
|
||||
"""
|
||||
Serialize dataset pipeline into a json file.
|
||||
Serialize dataset pipeline into a JSON file.
|
||||
|
||||
Note:
|
||||
Currently some python objects are not supported to be serialized.
|
||||
For python function serialization of map operator, de.serialize will only return its function name.
|
||||
Currently some Python objects are not supported to be serialized.
|
||||
For Python function serialization of map operator, de.serialize will only return its function name.
|
||||
|
||||
Args:
|
||||
dataset (Dataset): The starting node.
|
||||
json_filepath (str): The filepath where a serialized json file will be generated.
|
||||
json_filepath (str): The filepath where a serialized JSON file will be generated.
|
||||
|
||||
Returns:
|
||||
Dict, The dictionary contains the serialized dataset graph.
|
||||
|
@ -48,7 +48,7 @@ def serialize(dataset, json_filepath=""):
|
|||
>>> one_hot_encode = c_transforms.OneHot(10) # num_classes is input argument
|
||||
>>> dataset = dataset.map(operation=one_hot_encode, input_column_names="label")
|
||||
>>> dataset = dataset.batch(batch_size=10, drop_remainder=True)
|
||||
>>> # serialize it to json file
|
||||
>>> # serialize it to JSON file
|
||||
>>> ds.engine.serialize(dataset, json_filepath="/path/to/mnist_dataset_pipeline.json")
|
||||
>>> serialized_data = ds.engine.serialize(dataset) # serialize it to Python dict
|
||||
"""
|
||||
|
@ -57,27 +57,27 @@ def serialize(dataset, json_filepath=""):
|
|||
|
||||
def deserialize(input_dict=None, json_filepath=None):
|
||||
"""
|
||||
Construct a de pipeline from a json file produced by de.serialize().
|
||||
Construct a de pipeline from a JSON file produced by de.serialize().
|
||||
|
||||
Note:
|
||||
Currently python function deserialization of map operator are not supported.
|
||||
Currently Python function deserialization of map operator are not supported.
|
||||
|
||||
Args:
|
||||
input_dict (dict): A Python dictionary containing a serialized dataset graph.
|
||||
json_filepath (str): A path to the json file.
|
||||
json_filepath (str): A path to the JSON file.
|
||||
|
||||
Returns:
|
||||
de.Dataset or None if error occurs.
|
||||
|
||||
Raises:
|
||||
OSError: Can not open the json file.
|
||||
OSError: Can not open the JSON file.
|
||||
|
||||
Examples:
|
||||
>>> dataset = ds.MnistDataset(mnist_dataset_dir, 100)
|
||||
>>> one_hot_encode = c_transforms.OneHot(10) # num_classes is input argument
|
||||
>>> dataset = dataset.map(operation=one_hot_encode, input_column_names="label")
|
||||
>>> dataset = dataset.batch(batch_size=10, drop_remainder=True)
|
||||
>>> # Use case 1: to/from json file
|
||||
>>> # Use case 1: to/from JSON file
|
||||
>>> ds.engine.serialize(dataset, json_filepath="/path/to/mnist_dataset_pipeline.json")
|
||||
>>> dataset = ds.engine.deserialize(json_filepath="/path/to/mnist_dataset_pipeline.json")
|
||||
>>> # Use case 2: to/from Python dictionary
|
||||
|
@ -113,8 +113,15 @@ def show(dataset, indentation=2):
|
|||
|
||||
Args:
|
||||
dataset (Dataset): The starting node.
|
||||
indentation (int, optional): The indentation used by the json print.
|
||||
indentation (int, optional): The indentation used by the JSON print.
|
||||
Do not indent if indentation is None.
|
||||
|
||||
Examples:
|
||||
>>> dataset = ds.MnistDataset(mnist_dataset_dir, 100)
|
||||
>>> one_hot_encode = c_transforms.OneHot(10)
|
||||
>>> dataset = dataset.map(operation=one_hot_encode, input_column_names="label")
|
||||
>>> dataset = dataset.batch(batch_size=10, drop_remainder=True)
|
||||
>>> ds.show(dataset)
|
||||
"""
|
||||
|
||||
pipeline = dataset.to_json()
|
||||
|
@ -128,13 +135,21 @@ def compare(pipeline1, pipeline2):
|
|||
Args:
|
||||
pipeline1 (Dataset): a dataset pipeline.
|
||||
pipeline2 (Dataset): a dataset pipeline.
|
||||
|
||||
Returns:
|
||||
Whether pipeline1 is equal to pipeline2.
|
||||
|
||||
Examples:
|
||||
>>> pipeline1 = ds.MnistDataset(mnist_dataset_dir, 100)
|
||||
>>> pipeline2 = ds.Cifar10Dataset(cifar_dataset_dir, 100)
|
||||
>>> ds.compare(pipeline1, pipeline2)
|
||||
"""
|
||||
|
||||
return pipeline1.to_json() == pipeline2.to_json()
|
||||
|
||||
|
||||
def construct_pipeline(node):
|
||||
"""Construct the Python Dataset objects by following the dictionary deserialized from json file."""
|
||||
"""Construct the Python Dataset objects by following the dictionary deserialized from JSON file."""
|
||||
op_type = node.get('op_type')
|
||||
if not op_type:
|
||||
raise ValueError("op_type field in the json file can't be None.")
|
||||
|
|
Loading…
Reference in New Issue