!35081 fix-output_shape_doc

Merge pull request !35081 from luoyang/fix-output_shape_doc
i-robot 2022-05-30 01:22:29 +00:00 committed by Gitee
commit 57f66960ae
5 changed files with 42 additions and 10 deletions


@@ -203,6 +203,11 @@

Get the shape of each column of data in the dataset object.

**Parameters:**

- **estimate** (bool) - If `estimate` is False, the shape of the first data row of the dataset is returned. Otherwise, the whole dataset is traversed to obtain its real shape information, where dynamically varying dimensions are marked as -1 (this can be used for dynamic-shape dataset scenarios).

**Returns:**

list, a list of the shapes of each column.
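
To make the documented behavior concrete, a minimal usage sketch (not part of the commit); it assumes `mindspore.dataset` is imported as `ds`, and the shapes in the comments are indicative only:

    import numpy as np
    import mindspore.dataset as ds

    def gen():
        # the second dimension varies from row to row
        for i in range(1, 4):
            yield np.ones((16, i, 83), dtype=np.float32), np.array(i, dtype=np.int32)

    dataset = ds.GeneratorDataset(gen, ["data1", "data2"])

    # shape of the first data row only, e.g. [[16, 1, 83], []]
    print(dataset.output_shapes())

    # traverses the whole dataset; the varying dimension is reported as -1, e.g. [[16, -1, 83], []]
    print(dataset.output_shapes(estimate=True))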


@@ -641,14 +641,13 @@ def set_enable_shared_mem(enable):
        >>> # Enable shared memory feature to improve the performance of Python multiprocessing.
        >>> ds.config.set_enable_shared_mem(True)
    """
+    # For Windows and MacOS we forbid shared mem function temporarily
+    if platform.system().lower() in {"windows", "darwin"}:
+        logger.warning("For Windows and MacOS we forbid shared mem function temporarily.")
+        return
    if not isinstance(enable, bool):
        raise TypeError("enable must be of type bool.")
    if enable:
-        # For Windows and MacOS we forbid shared mem function temporarily
-        if platform.system().lower() in {"windows", "darwin"}:
-            logger.warning("For Windows and MacOS we forbid shared mem function temporarily.")
-            return
        logger.warning("The shared memory is on, multiprocessing performance will be improved. "
                       "Note: the required shared memory can't exceeds 80% of the available shared memory.")
    _config.set_enable_shared_mem(enable)
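
A minimal sketch of how this setting is typically combined with Python multiprocessing; the generator, transform, and worker count below are illustrative assumptions, not part of the commit:

    import numpy as np
    import mindspore.dataset as ds

    # On Windows/macOS this now only logs a warning and returns early (see the change above).
    ds.config.set_enable_shared_mem(True)

    def gen():
        for i in range(8):
            yield (np.ones((224, 224, 3), dtype=np.float32) * i,)

    def double(image):
        return image * 2.0

    dataset = ds.GeneratorDataset(gen, ["image"])
    # With shared memory enabled, the multiprocessing workers pass rows back via shared memory.
    dataset = dataset.map(operations=[double], input_columns=["image"],
                          num_parallel_workers=2, python_multiprocessing=True)

    for _ in dataset.create_tuple_iterator(output_numpy=True):
        pass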


@@ -66,7 +66,7 @@ from .iterators import DictIterator, TupleIterator, DummyIterator, check_iterato
    ITERATORS_LIST, _unset_iterator_cleanup
from .queue import _SharedQueue
from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
-    check_rename, check_device_send, check_take, check_project, \
+    check_rename, check_device_send, check_take, check_output_shape, check_project, \
    check_sync_wait, check_zip_dataset, check_add_column, check_concat, check_split, check_bucket_batch_by_length, \
    check_save, check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, deprecated
from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
@@ -1576,17 +1576,27 @@ class Dataset:
            runtime_getter[2].notify_watchdog()
        return self._col_names

+    @check_output_shape
    def output_shapes(self, estimate=False):
        """
-        Get the shapes of output data. If `estimate` is False, will return the shape of first data row.
-        Otherwise, will iterate the whole dataset and return the estimated shape of data row, where dynamic
-        shape marked is marked as -1.
+        Get the shapes of output data.
+
+        Args:
+            estimate (bool): If `estimate` is False, will return the shapes of first data row.
+                Otherwise, will iterate the whole dataset and return the estimated shapes of data row,
+                where dynamic shape is marked as -1 (used in dynamic data shapes scenario).

        Returns:
            list, list of shapes of each column.

        Examples:
-            >>> # dataset is an instance object of Dataset
+            >>> import numpy as np
+            >>>
+            >>> def generator1():
+            ...     for i in range(1, 100):
+            ...         yield np.ones((16, i, 83)), np.array(i)
+            >>>
+            >>> dataset = ds.GeneratorDataset(generator1, ["data1", "data2"])
            >>> output_shapes = dataset.output_shapes()
        """
        # cache single shape
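
The estimate=True path amounts to merging the per-row shapes column by column and marking any dimension that differs across rows as -1. A standalone sketch of that idea (an illustration under that assumption, not MindSpore's internal implementation, which is delegated to the runtime getters):

    from typing import List, Sequence

    def merge_shapes(per_row_shapes: Sequence[Sequence[Sequence[int]]]) -> List[List[int]]:
        """Merge per-row column shapes; dimensions that differ across rows become -1."""
        merged = [list(shape) for shape in per_row_shapes[0]]
        for row in per_row_shapes[1:]:
            for col, shape in enumerate(row):
                if len(shape) != len(merged[col]):
                    raise RuntimeError("Inconsistent shapes, expect same shape for each data row")
                merged[col] = [a if a == b else -1 for a, b in zip(merged[col], shape)]
        return merged

    # rows shaped (16, 1, 83), (16, 2, 83), (16, 3, 83) plus a scalar column -> [[16, -1, 83], []]
    print(merge_shapes([[(16, i, 83), ()] for i in range(1, 4)]))

The error message in the sketch mirrors the one asserted in the test change further down.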


@@ -1524,6 +1524,20 @@ def check_rename(method):
    return new_method


+def check_output_shape(method):
+    """check the input arguments of output_shape."""
+
+    @wraps(method)
+    def new_method(self, *args, **kwargs):
+        _, param_dict = parse_user_args(method, *args, **kwargs)
+        estimate = param_dict.get('estimate')
+        type_check(estimate, (bool,), "estimate")
+
+        return method(self, *args, **kwargs)
+
+    return new_method


def check_project(method):
    """check the input arguments of project."""


@@ -258,6 +258,10 @@ def test_output_shapes_exception():
        _ = dataset.output_shapes(estimate=True)
    assert "Inconsistent shapes, expect same shape for each data row" in str(info.value)

+    with pytest.raises(TypeError) as info:
+        dataset = ds.GeneratorDataset(generator3, ["data1", "data2", "data3"])
+        _ = dataset.output_shapes(estimate=1)


if __name__ == "__main__":
    test_get_dynamic_min_max_shapes_0()
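
For reference, a self-contained approximation of the new negative test; the local `gen` below is an assumption standing in for the `generator3` fixture defined earlier in the real test file:

    import numpy as np
    import pytest
    import mindspore.dataset as ds

    def test_output_shapes_estimate_type_error():
        """output_shapes should reject a non-bool `estimate` argument."""

        def gen():
            for i in range(1, 4):
                yield np.ones((16, i, 83)), np.array(i), np.array([i, i])

        dataset = ds.GeneratorDataset(gen, ["data1", "data2", "data3"])
        with pytest.raises(TypeError):
            _ = dataset.output_shapes(estimate=1)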