forked from mindspore-Ecosystem/mindspore
!35081 fix-output_shape_doc
Merge pull request !35081 from luoyang/fix-output_shape_doc
This commit is contained in:
commit
57f66960ae
|
@ -203,6 +203,11 @@
|
|||
|
||||
获取数据集对象中每列数据的shape。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **estimate** (bool) - 如果 `estimate` 为 False,将返回数据集第一条数据的shape。
|
||||
否则将遍历整个数据集以获取数据集的真实shape信息,其中动态变化的维度将被标记为-1(可用于动态shape数据集场景)。
|
||||
|
||||
**返回:**
|
||||
|
||||
list,每列数据的shape列表。
|
||||
|
|
|
@ -641,14 +641,13 @@ def set_enable_shared_mem(enable):
|
|||
>>> # Enable shared memory feature to improve the performance of Python multiprocessing.
|
||||
>>> ds.config.set_enable_shared_mem(True)
|
||||
"""
|
||||
# For Windows and MacOS we forbid shared mem function temporarily
|
||||
if platform.system().lower() in {"windows", "darwin"}:
|
||||
logger.warning("For Windows and MacOS we forbid shared mem function temporarily.")
|
||||
return
|
||||
|
||||
if not isinstance(enable, bool):
|
||||
raise TypeError("enable must be of type bool.")
|
||||
if enable:
|
||||
# For Windows and MacOS we forbid shared mem function temporarily
|
||||
if platform.system().lower() in {"windows", "darwin"}:
|
||||
logger.warning("For Windows and MacOS we forbid shared mem function temporarily.")
|
||||
return
|
||||
logger.warning("The shared memory is on, multiprocessing performance will be improved. "
|
||||
"Note: the required shared memory can't exceeds 80% of the available shared memory.")
|
||||
_config.set_enable_shared_mem(enable)
|
||||
|
|
|
@ -66,7 +66,7 @@ from .iterators import DictIterator, TupleIterator, DummyIterator, check_iterato
|
|||
ITERATORS_LIST, _unset_iterator_cleanup
|
||||
from .queue import _SharedQueue
|
||||
from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
|
||||
check_rename, check_device_send, check_take, check_project, \
|
||||
check_rename, check_device_send, check_take, check_output_shape, check_project, \
|
||||
check_sync_wait, check_zip_dataset, check_add_column, check_concat, check_split, check_bucket_batch_by_length, \
|
||||
check_save, check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, deprecated
|
||||
from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
|
||||
|
@ -1576,17 +1576,27 @@ class Dataset:
|
|||
runtime_getter[2].notify_watchdog()
|
||||
return self._col_names
|
||||
|
||||
@check_output_shape
|
||||
def output_shapes(self, estimate=False):
|
||||
"""
|
||||
Get the shapes of output data. If `estimate` is False, will return the shape of first data row.
|
||||
Otherwise, will iterate the whole dataset and return the estimated shape of data row, where dynamic
|
||||
shape marked is marked as -1.
|
||||
Get the shapes of output data.
|
||||
|
||||
Args:
|
||||
estimate (bool): If `estimate` is False, will return the shapes of first data row.
|
||||
Otherwise, will iterate the whole dataset and return the estimated shapes of data row,
|
||||
where dynamic shape is marked as -1 (used in dynamic data shapes scenario).
|
||||
|
||||
Returns:
|
||||
list, list of shapes of each column.
|
||||
|
||||
Examples:
|
||||
>>> # dataset is an instance object of Dataset
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> def generator1():
|
||||
... for i in range(1, 100):
|
||||
... yield np.ones((16, i, 83)), np.array(i)
|
||||
>>>
|
||||
>>> dataset = ds.GeneratorDataset(generator1, ["data1", "data2"])
|
||||
>>> output_shapes = dataset.output_shapes()
|
||||
"""
|
||||
# cache single shape
|
||||
|
|
|
@ -1524,6 +1524,20 @@ def check_rename(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_output_shape(method):
|
||||
"""check the input arguments of output_shape."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
_, param_dict = parse_user_args(method, *args, **kwargs)
|
||||
estimate = param_dict.get('estimate')
|
||||
type_check(estimate, (bool,), "estimate")
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_project(method):
|
||||
"""check the input arguments of project."""
|
||||
|
||||
|
|
|
@ -258,6 +258,10 @@ def test_output_shapes_exception():
|
|||
_ = dataset.output_shapes(estimate=True)
|
||||
assert "Inconsistent shapes, expect same shape for each data row" in str(info.value)
|
||||
|
||||
with pytest.raises(TypeError) as info:
|
||||
dataset = ds.GeneratorDataset(generator3, ["data1", "data2", "data3"])
|
||||
_ = dataset.output_shapes(estimate=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_get_dynamic_min_max_shapes_0()
|
||||
|
|
Loading…
Reference in New Issue