!47570 add do_copy parameter for create_dict_iterator

Merge pull request !47570 from guozhijian/add_do_copy_for_dict_iterator
This commit is contained in:
i-robot 2023-01-07 02:27:35 +00:00 committed by Gitee
commit 5a0208b616
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 29 additions and 5 deletions

View File

@ -1,13 +1,14 @@
mindspore.dataset.Dataset.create_dict_iterator
==============================================
.. py:method:: mindspore.dataset.Dataset.create_dict_iterator(num_epochs=-1, output_numpy=False)
.. py:method:: mindspore.dataset.Dataset.create_dict_iterator(num_epochs=-1, output_numpy=False, do_copy=True)
基于数据集对象创建迭代器。输出的数据为字典类型。
参数:
- **num_epochs** (int, 可选) - 迭代器可以迭代的最大次数。默认值:-1迭代器可以迭代无限次。
- **output_numpy** (bool, 可选) - 输出的数据是否转为NumPy类型。如果为False迭代器输出的每列数据类型为MindSpore.Tensor否则为NumPy。默认值False。
- **do_copy** (bool, 可选) - 当参数 `output_numpy` 为False即输出数据类型为mindspore.Tensor时可以将此参数指定为False以减少拷贝获得更好的性能。默认值True。
返回:
DictIterator基于数据集对象创建的字典迭代器。

View File

@ -1441,7 +1441,7 @@ class Dataset:
Default: -1, iterator can be iterated infinite number of epochs.
output_numpy (bool, optional): Whether or not to output NumPy datatype.
If output_numpy=False, iterator will output MSTensor. Default: False.
do_copy (bool, optional): when output data type is mindspore.Tensor,
do_copy (bool, optional): When output data type is mindspore.Tensor,
use this param to select the conversion method, only take False for better performance. Default: True.
Returns:
@ -1464,7 +1464,7 @@ class Dataset:
return TupleIterator(self, columns, num_epochs, output_numpy, do_copy)
@check_dict_iterator
def create_dict_iterator(self, num_epochs=-1, output_numpy=False):
def create_dict_iterator(self, num_epochs=-1, output_numpy=False, do_copy=True):
"""
Create an iterator over the dataset. The data retrieved will be a dictionary datatype.
@ -1473,6 +1473,8 @@ class Dataset:
Default: -1, iterator can be iterated infinite number of epochs.
output_numpy (bool, optional): Whether or not to output NumPy datatype,
if output_numpy=False, iterator will output MSTensor. Default: False.
do_copy (bool, optional): When output data type is mindspore.Tensor,
use this param to select the conversion method, only take False for better performance. Default: True.
Returns:
Iterator, dictionary iterator over the dataset.
@ -1491,7 +1493,7 @@ class Dataset:
if Dataset._noop_mode():
return DummyIterator(self, 'dict', output_numpy)
return DictIterator(self, num_epochs, output_numpy)
return DictIterator(self, num_epochs, output_numpy, do_copy)
def __iter__(self):
"""Create an iterator over the dataset."""

View File

@ -996,7 +996,7 @@ def check_dict_iterator(method):
@wraps(method)
def new_method(self, *args, **kwargs):
[num_epochs, _], param_dict = parse_user_args(method, *args, **kwargs)
[num_epochs, _, _], param_dict = parse_user_args(method, *args, **kwargs)
nreq_param_bool = ['output_numpy']
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
if num_epochs is not None:

View File

@ -49,6 +49,7 @@ def test_iterator_create_tuple_numpy():
check(COLUMNS[7:8])
check(COLUMNS[0:2:8])
def test_iterator_create_dict_mstensor():
"""
Feature: Iterator
@ -71,6 +72,16 @@ def test_iterator_create_dict_mstensor():
i += 1
assert i == 64
i = 0
for item in data1.create_dict_iterator(num_epochs=1, do_copy=False):
golden = np.array([i], dtype=np.float32)
np.testing.assert_array_equal(item["data"].asnumpy(), golden)
assert isinstance(item["data"], Tensor)
assert item["data"].dtype == mstype.float32
i += 1
assert i == 64
def test_iterator_create_tuple_mstensor():
"""
Feature: Iterator
@ -93,6 +104,15 @@ def test_iterator_create_tuple_mstensor():
i += 1
assert i == 64
i = 0
for item in data1.create_tuple_iterator(num_epochs=1, do_copy=False):
golden = np.array([i], dtype=np.float32)
np.testing.assert_array_equal(item[0].asnumpy(), golden)
assert isinstance(item[0], Tensor)
assert item[0].dtype == mstype.float32
i += 1
assert i == 64
def test_iterator_weak_ref():
"""
@ -135,6 +155,7 @@ def test_iterator_weak_ref():
_cleanup()
def test_iterator_exception():
"""
Feature: Iterator