modify param of data sink

yuzhenhua 2022-10-28 14:49:13 +08:00
parent bafedab775
commit 59338f30e8
3 changed files with 15 additions and 16 deletions
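For reference, a minimal before/after sketch of how a call site migrates under this change, assuming MindSpore with an Ascend or GPU device target (the two-sample dataset and the function `func_net` below are illustrative, modeled on the docstring example further down in this diff; they are not part of the commit): instead of passing a total `steps` count, the caller now decides how many sink steps to run by invoking the generated function in a loop.

import numpy as np
import mindspore as ms
import mindspore.dataset as ds

# Illustrative two-column, two-sample dataset; any mindspore.dataset iterator works here.
data = {"x": np.array([[1.0], [3.0]], np.float32), "y": np.array([[2.0], [4.0]], np.float32)}
dataset = ds.NumpySlicesDataset(data=data)

def func_net(x, y):
    return x + y

# Before this commit, a total step count was passed explicitly:
#     sink_process = ms.train.data_sink(func_net, dataset, steps=2, sink_size=1)
# After this commit, `steps` is gone; the caller drives the number of sink steps.
sink_process = ms.train.data_sink(func_net, dataset, sink_size=1)
for _ in range(2):
    out = sink_process()
    print(out)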

View File

@@ -1,7 +1,7 @@
mindspore.data_sink
===================
.. py:function:: mindspore.data_sink(fn, dataset, steps, sink_size=1, jit_config=None, input_signature=None)
.. py:function:: mindspore.data_sink(fn, dataset, sink_size=1, jit_config=None, input_signature=None)
Wrap the input function to generate a new function.
@@ -10,7 +10,6 @@ mindspore.data_sink
Parameters:
- **fn** (Function) - The function that will be run together with the dataset.
- **dataset** (Dataset) - The training dataset iterator. The dataset can be generated by the dataset generator APIs in :class:`mindspore.dataset`, such as :class:`mindspore.dataset.ImageFolderDataset`.
- **steps** (int) - The total number of running steps. `steps` must be a positive integer.
- **sink_size** (int) - Controls the amount of data executed in each sink. `sink_size` must be a positive integer. Default: 1.
- **jit_config** (JitConfig) - The JitConfig configuration used during compilation. For details, refer to :class:`mindspore.JitConfig`. Default: None, which means running in PyNative mode.
- **input_signature** (Union[Tensor, List or Tuple of Tensors]) - The Tensor(s) representing the input arguments. The shape and dtype of the Tensors will be used as the input shape and dtype of the function. Default: None.
@@ -19,4 +18,4 @@ mindspore.data_sink
Function, the generated function will be executed in data sinking mode.
Raises:
- **ValueError** - If `steps` or `sink_size` is not a positive integer.
- **ValueError** - If `sink_size` is not a positive integer.

View File

@@ -13,7 +13,6 @@
# limitations under the License.
# ============================================================================
"""Data sink help for minddata dataset"""
import math
from functools import wraps
import mindspore.ops as ops
from mindspore import context
@@ -26,7 +25,7 @@ from mindspore._c_expression import _set_dataset_mode_config
from mindspore.parallel._utils import _get_device_num, _need_to_full, _to_full_shapes, _get_pipeline_stages
def _init_sink_dataset(dataset, steps, sink_size, input_signature):
def _init_sink_dataset(dataset, sink_size, input_signature):
"""
Initialize data sinking
"""
@@ -44,8 +43,7 @@ def _init_sink_dataset(dataset, steps, sink_size, input_signature):
dataset.__transfer_dataset__ = transfer_dataset
# send data
sink_count = math.ceil(steps/dataset_size)
transfer_dataset.send(sink_count)
transfer_dataset.send(-1)
# create GetNext op
if input_signature is not None:
@@ -125,7 +123,7 @@ def _get_sink_fun(sink_fun, key_info, is_info_queue, dataset, jit_config):
return dst_sink_fun
def data_sink(fn, dataset, steps, sink_size=1, jit_config=None, input_signature=None):
def data_sink(fn, dataset, sink_size=1, jit_config=None, input_signature=None):
"""
A wrapper function to generate a function for the input function.
@@ -133,7 +131,6 @@ def data_sink(fn, dataset, steps, sink_size=1, jit_config=None, input_signature=
fn (Function): The Python function that will be run with the dataset.
dataset (Dataset): The dataset iterator. The dataset can be generated by the dataset generator APIs in
:class:`mindspore.dataset`, such as :class:`mindspore.dataset.ImageFolderDataset`.
steps (int): The total number of running steps. `steps` must be a positive integer.
sink_size (int): Controls the amount of data in each sink. `sink_size` must be a positive integer. Default: 1.
jit_config (JitConfig): Controls the execution mode (Graph mode/PyNative mode) of the generated function, and the
JIT config used for compilation. Default: None, which means running in PyNative mode.
@@ -147,7 +144,7 @@ def data_sink(fn, dataset, steps, sink_size=1, jit_config=None, input_signature=
Function, the generated function will be executed in data sinking mode.
Raises:
ValueError: If `steps` or `sink_size` is not a positive integer.
ValueError: If `sink_size` is not a positive integer.
Supported Platforms:
``Ascend`` ``GPU``
@@ -164,7 +161,7 @@ def data_sink(fn, dataset, steps, sink_size=1, jit_config=None, input_signature=
... out = x + y
... return out
>>>
>>> sink_process = ms.train.data_sink(func_net, dataset, steps=2, sink_size=1)
>>> sink_process = ms.train.data_sink(func_net, dataset, sink_size=1)
>>> for _ in range(2):
... out = sink_process()
... print(out)
@@ -172,16 +169,19 @@ def data_sink(fn, dataset, steps, sink_size=1, jit_config=None, input_signature=
2
"""
if sink_size <= 0 or steps <= 0:
if sink_size <= 0:
raise ValueError(
f"The 'steps' and 'sink_size' must be positive, but got steps {steps} sink_size {sink_size}.")
f"The 'sink_size' must be positive, but got sink_size {sink_size}.")
if context.get_context('device_target') not in ('Ascend', 'GPU'):
raise ValueError(
f"Data sinking supports ascend or gpu device target, "
f"but device target is {context.get_context('device_target')}.")
ori_next_op, is_info_queue = _init_sink_dataset(dataset, steps, sink_size, input_signature)
loop = sink_size
if jit_config is not None:
loop = 1
ori_next_op, is_info_queue = _init_sink_dataset(dataset, loop, input_signature)
@wraps(fn)
def sink_process(*args, **kwargs):

View File

@@ -84,7 +84,7 @@ def test_sink():
data = {"input": np.ones([16, 32, 128]).astype(np.float32), "label": np.zeros([16, 32, 768]).astype(np.float32)}
dataset = ds.NumpySlicesDataset(data=data)
jitconfig = JitConfig(jit_level="O1", task_sink=True)
sink_process = ms.train.data_sink(dense_func, dataset, steps=2, sink_size=4, jit_config=jitconfig)
sink_process = ms.train.data_sink(dense_func, dataset, sink_size=4, jit_config=jitconfig)
_ = sink_process()
@@ -111,7 +111,7 @@ def test_sink_with_grad():
data = {"input": np.ones([16, 32, 128]).astype(np.float32), "label": np.zeros([16, 32, 768]).astype(np.float32)}
dataset = ds.NumpySlicesDataset(data=data)
jitconfig = JitConfig(jit_level="O1", task_sink=True)
sink_process = ms.train.data_sink(train_step, dataset, steps=2, sink_size=4, jit_config=jitconfig)
sink_process = ms.train.data_sink(train_step, dataset, sink_size=4, jit_config=jitconfig)
_ = sink_process()
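For completeness, a self-contained sketch of the graph-mode call pattern exercised by the updated tests above, again assuming an Ascend or GPU device (the data shapes, the JitConfig options, and the absence of a `steps` argument are taken from the tests; the function `func` is a placeholder, since the bodies of `dense_func` and `train_step` are not shown in this diff):

import numpy as np
import mindspore as ms
import mindspore.dataset as ds
from mindspore import JitConfig

def func(x, y):
    # Placeholder computation standing in for the test networks.
    return x + 1.0

data = {"input": np.ones([16, 32, 128]).astype(np.float32),
        "label": np.zeros([16, 32, 768]).astype(np.float32)}
dataset = ds.NumpySlicesDataset(data=data)

jitconfig = JitConfig(jit_level="O1", task_sink=True)
# sink_size=4 controls how much data each sink processes; how many sinks run
# is now determined solely by how many times sink_process() is called.
sink_process = ms.train.data_sink(func, dataset, sink_size=4, jit_config=jitconfig)
_ = sink_process()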