forked from mindspore-Ecosystem/mindspore
add do_copy parameter for create_dict_iterator
This commit is contained in:
parent
b562edf89e
commit
6b3e1b4896
|
@ -1,13 +1,14 @@
|
|||
mindspore.dataset.Dataset.create_dict_iterator
|
||||
==============================================
|
||||
|
||||
.. py:method:: mindspore.dataset.Dataset.create_dict_iterator(num_epochs=-1, output_numpy=False)
|
||||
.. py:method:: mindspore.dataset.Dataset.create_dict_iterator(num_epochs=-1, output_numpy=False, do_copy=True)
|
||||
|
||||
基于数据集对象创建迭代器。输出的数据为字典类型。
|
||||
|
||||
参数:
|
||||
- **num_epochs** (int, 可选) - 迭代器可以迭代的最大次数。默认值:-1,迭代器可以迭代无限次。
|
||||
- **output_numpy** (bool, 可选) - 输出的数据是否转为NumPy类型。如果为False,迭代器输出的每列数据类型为MindSpore.Tensor,否则为NumPy。默认值:False。
|
||||
- **do_copy** (bool, 可选) - 当参数 `output_numpy` 为False,即输出数据类型为mindspore.Tensor时,可以将此参数指定为False以减少拷贝,获得更好的性能。默认值:True。
|
||||
|
||||
返回:
|
||||
DictIterator,基于数据集对象创建的字典迭代器。
|
||||
|
|
|
@ -1441,7 +1441,7 @@ class Dataset:
|
|||
Default: -1, iterator can be iterated infinite number of epochs.
|
||||
output_numpy (bool, optional): Whether or not to output NumPy datatype.
|
||||
If output_numpy=False, iterator will output MSTensor. Default: False.
|
||||
do_copy (bool, optional): when output data type is mindspore.Tensor,
|
||||
do_copy (bool, optional): When output data type is mindspore.Tensor,
|
||||
use this param to select the conversion method, only take False for better performance. Default: True.
|
||||
|
||||
Returns:
|
||||
|
@ -1464,7 +1464,7 @@ class Dataset:
|
|||
return TupleIterator(self, columns, num_epochs, output_numpy, do_copy)
|
||||
|
||||
@check_dict_iterator
|
||||
def create_dict_iterator(self, num_epochs=-1, output_numpy=False):
|
||||
def create_dict_iterator(self, num_epochs=-1, output_numpy=False, do_copy=True):
|
||||
"""
|
||||
Create an iterator over the dataset. The data retrieved will be a dictionary datatype.
|
||||
|
||||
|
@ -1473,6 +1473,8 @@ class Dataset:
|
|||
Default: -1, iterator can be iterated infinite number of epochs.
|
||||
output_numpy (bool, optional): Whether or not to output NumPy datatype,
|
||||
if output_numpy=False, iterator will output MSTensor. Default: False.
|
||||
do_copy (bool, optional): When output data type is mindspore.Tensor,
|
||||
use this param to select the conversion method, only take False for better performance. Default: True.
|
||||
|
||||
Returns:
|
||||
Iterator, dictionary iterator over the dataset.
|
||||
|
@ -1491,7 +1493,7 @@ class Dataset:
|
|||
|
||||
if Dataset._noop_mode():
|
||||
return DummyIterator(self, 'dict', output_numpy)
|
||||
return DictIterator(self, num_epochs, output_numpy)
|
||||
return DictIterator(self, num_epochs, output_numpy, do_copy)
|
||||
|
||||
def __iter__(self):
|
||||
"""Create an iterator over the dataset."""
|
||||
|
|
|
@ -996,7 +996,7 @@ def check_dict_iterator(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[num_epochs, _], param_dict = parse_user_args(method, *args, **kwargs)
|
||||
[num_epochs, _, _], param_dict = parse_user_args(method, *args, **kwargs)
|
||||
nreq_param_bool = ['output_numpy']
|
||||
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
|
||||
if num_epochs is not None:
|
||||
|
|
|
@ -49,6 +49,7 @@ def test_iterator_create_tuple_numpy():
|
|||
check(COLUMNS[7:8])
|
||||
check(COLUMNS[0:2:8])
|
||||
|
||||
|
||||
def test_iterator_create_dict_mstensor():
|
||||
"""
|
||||
Feature: Iterator
|
||||
|
@ -71,6 +72,16 @@ def test_iterator_create_dict_mstensor():
|
|||
i += 1
|
||||
assert i == 64
|
||||
|
||||
i = 0
|
||||
for item in data1.create_dict_iterator(num_epochs=1, do_copy=False):
|
||||
golden = np.array([i], dtype=np.float32)
|
||||
np.testing.assert_array_equal(item["data"].asnumpy(), golden)
|
||||
assert isinstance(item["data"], Tensor)
|
||||
assert item["data"].dtype == mstype.float32
|
||||
i += 1
|
||||
assert i == 64
|
||||
|
||||
|
||||
def test_iterator_create_tuple_mstensor():
|
||||
"""
|
||||
Feature: Iterator
|
||||
|
@ -93,6 +104,15 @@ def test_iterator_create_tuple_mstensor():
|
|||
i += 1
|
||||
assert i == 64
|
||||
|
||||
i = 0
|
||||
for item in data1.create_tuple_iterator(num_epochs=1, do_copy=False):
|
||||
golden = np.array([i], dtype=np.float32)
|
||||
np.testing.assert_array_equal(item[0].asnumpy(), golden)
|
||||
assert isinstance(item[0], Tensor)
|
||||
assert item[0].dtype == mstype.float32
|
||||
i += 1
|
||||
assert i == 64
|
||||
|
||||
|
||||
def test_iterator_weak_ref():
|
||||
"""
|
||||
|
@ -135,6 +155,7 @@ def test_iterator_weak_ref():
|
|||
|
||||
_cleanup()
|
||||
|
||||
|
||||
def test_iterator_exception():
|
||||
"""
|
||||
Feature: Iterator
|
||||
|
|
Loading…
Reference in New Issue