!44944 del set_dynamic_columns operation

Merge pull request !44944 from 刘勇琪/master
This commit is contained in:
i-robot 2022-11-03 09:46:39 +00:00 committed by Gitee
commit 174f46f24d
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
14 changed files with 40 additions and 371 deletions

View File

@ -1,9 +0,0 @@
mindspore.dataset.Dataset.dynamic_min_max_shapes
================================================
.. py:method:: mindspore.dataset.Dataset.dynamic_min_max_shapes()
当数据集对象中的数据shape不唯一动态shape获取数据的最小shape和最大shape。
返回:
两个列表代表最小shape和最大shape每个列表中的shape按照数据列的顺序排列。

View File

@ -1,9 +0,0 @@
mindspore.dataset.Dataset.set_dynamic_columns
=============================================
.. py:method:: mindspore.dataset.Dataset.set_dynamic_columns(columns=None)
设置数据集的动态shape信息需要在定义好完整的数据处理管道后进行设置。
参数:
- **columns** (dict) - 包含数据集中每列shape信息的字典。shape[i]为 `None` 表示shape[i]的数据长度是动态的。

View File

@ -16,7 +16,6 @@
mindspore.dataset.Dataset.repeat
mindspore.dataset.Dataset.reset
mindspore.dataset.Dataset.save
mindspore.dataset.Dataset.set_dynamic_columns
mindspore.dataset.Dataset.shuffle
mindspore.dataset.Dataset.skip
mindspore.dataset.Dataset.split
@ -54,7 +53,6 @@ Batch批操作
:nosignatures:
:template: classtemplate.rst
mindspore.dataset.Dataset.dynamic_min_max_shapes
mindspore.dataset.Dataset.get_batch_size
mindspore.dataset.Dataset.get_class_indexing
mindspore.dataset.Dataset.get_col_names

View File

@ -16,7 +16,6 @@
mindspore.dataset.Dataset.repeat
mindspore.dataset.Dataset.reset
mindspore.dataset.Dataset.save
mindspore.dataset.Dataset.set_dynamic_columns
mindspore.dataset.Dataset.shuffle
mindspore.dataset.Dataset.skip
mindspore.dataset.Dataset.split
@ -56,7 +55,6 @@ Batch批操作
:nosignatures:
:template: classtemplate.rst
mindspore.dataset.Dataset.dynamic_min_max_shapes
mindspore.dataset.Dataset.get_batch_size
mindspore.dataset.Dataset.get_class_indexing
mindspore.dataset.Dataset.get_col_names

View File

@ -16,7 +16,6 @@
mindspore.dataset.Dataset.repeat
mindspore.dataset.Dataset.reset
mindspore.dataset.Dataset.save
mindspore.dataset.Dataset.set_dynamic_columns
mindspore.dataset.Dataset.shuffle
mindspore.dataset.Dataset.skip
mindspore.dataset.Dataset.split
@ -54,7 +53,6 @@ Batch批操作
:nosignatures:
:template: classtemplate.rst
mindspore.dataset.Dataset.dynamic_min_max_shapes
mindspore.dataset.Dataset.get_batch_size
mindspore.dataset.Dataset.get_class_indexing
mindspore.dataset.Dataset.get_col_names

View File

@ -20,10 +20,6 @@ mindspore.DatasetHelper
在epoch开始时继续向设备发送数据。
.. py:method:: dynamic_min_max_shapes()
返回动态数据的形状(shape)范围(最小形状(shape),最大形状(shape))。
.. py:method:: get_data_info()
下沉模式下,获取当前批次数据的类型和形状(shape)。通常在数据形状(shape)动态变化的场景使用。

View File

@ -27,7 +27,6 @@ Pre-processing Operation
mindspore.dataset.Dataset.rename
mindspore.dataset.Dataset.repeat
mindspore.dataset.Dataset.reset
mindspore.dataset.Dataset.set_dynamic_columns
mindspore.dataset.Dataset.shuffle
mindspore.dataset.Dataset.skip
mindspore.dataset.Dataset.split
@ -65,7 +64,6 @@ Attribute
:nosignatures:
mindspore.dataset.Dataset.dynamic_min_max_shapes
mindspore.dataset.Dataset.get_batch_size
mindspore.dataset.Dataset.get_class_indexing
mindspore.dataset.Dataset.get_col_names
@ -124,7 +122,6 @@ Pre-processing Operation
mindspore.dataset.Dataset.repeat
mindspore.dataset.Dataset.reset
mindspore.dataset.Dataset.save
mindspore.dataset.Dataset.set_dynamic_columns
mindspore.dataset.Dataset.shuffle
mindspore.dataset.Dataset.skip
mindspore.dataset.Dataset.split
@ -162,7 +159,6 @@ Attribute
:nosignatures:
mindspore.dataset.Dataset.dynamic_min_max_shapes
mindspore.dataset.Dataset.get_batch_size
mindspore.dataset.Dataset.get_class_indexing
mindspore.dataset.Dataset.get_col_names
@ -222,7 +218,6 @@ Pre-processing Operation
mindspore.dataset.Dataset.repeat
mindspore.dataset.Dataset.reset
mindspore.dataset.Dataset.save
mindspore.dataset.Dataset.set_dynamic_columns
mindspore.dataset.Dataset.shuffle
mindspore.dataset.Dataset.skip
mindspore.dataset.Dataset.split
@ -260,7 +255,6 @@ Attribute
:nosignatures:
mindspore.dataset.Dataset.dynamic_min_max_shapes
mindspore.dataset.Dataset.get_batch_size
mindspore.dataset.Dataset.get_class_indexing
mindspore.dataset.Dataset.get_col_names
@ -320,7 +314,6 @@ Pre-processing Operation
mindspore.dataset.Dataset.repeat
mindspore.dataset.Dataset.reset
mindspore.dataset.Dataset.save
mindspore.dataset.Dataset.set_dynamic_columns
mindspore.dataset.Dataset.shuffle
mindspore.dataset.Dataset.skip
mindspore.dataset.Dataset.split
@ -359,7 +352,6 @@ Attribute
:nosignatures:
mindspore.dataset.Dataset.dynamic_min_max_shapes
mindspore.dataset.Dataset.get_batch_size
mindspore.dataset.Dataset.get_class_indexing
mindspore.dataset.Dataset.get_col_names

View File

@ -45,7 +45,6 @@ import copy
import weakref
import platform
import psutil
import numpy as np
import mindspore._c_dataengine as cde
from mindspore._c_expression import typing
@ -68,8 +67,7 @@ from .queue import _SharedQueue, _Queue
from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
check_rename, check_device_send, check_take, check_output_shape, check_project, \
check_sync_wait, check_zip_dataset, check_add_column, check_concat, check_split, check_bucket_batch_by_length, \
check_save, check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, check_padded_batch, \
deprecated
check_save, check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, check_padded_batch
from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
get_enable_watchdog, get_seed, set_seed
from ..core.datatypes import mstype_to_detype
@ -1584,11 +1582,6 @@ class Dataset:
if estimate and self.estimated_output_shapes is not None:
return self.estimated_output_shapes
# if use set_dynamic_column, the `estimate` does not work, but they get the same result
if self.dynamic_setting[0]:
self.saved_output_shapes, self.saved_min_shapes, self.saved_max_shapes = self._dynamic_output_shapes()
return self.saved_output_shapes
# We have a hang problem when two-level pipeline with multiprocessing, we need to extend the life cycle
# of runtime_context. We found this hang problem only occur on output_types and output_shapes.
runtime_getter = self._init_tree_getters()
@ -1651,52 +1644,6 @@ class Dataset:
return self.dataset_size
@deprecated("1.5")
def set_dynamic_columns(self, columns=None):
"""
Set dynamic shape information of source data, it should be set after the pipeline is defined.
Args:
columns (dict): A dict contains shape information of each column in dataset.
The value of shape[i] is :py:obj:`None` indicates that the data length of shape[i] is dynamic.
Examples:
>>> import numpy as np
>>>
>>> def generator1():
... for i in range(1, 100):
... yield np.ones((16, i, 83)), np.array(i)
>>>
>>> dataset = ds.GeneratorDataset(generator1, ["data1", "data2"])
>>> dataset.set_dynamic_columns(columns={"data1": [16, None, 83], "data2": []})
"""
if not isinstance(columns, dict):
raise TypeError("Pass a dict to set dynamic shape, example: {\"data1\": [16, None, 256]}")
self.dynamic_setting[0] = True
self.dynamic_setting[1] = columns
def dynamic_min_max_shapes(self):
"""
Get minimum and maximum data length of dynamic source data, for dynamic graph compilation.
Returns:
lists, min_shapes, max_shapes of source data.
Examples:
>>> import numpy as np
>>>
>>> def generator1():
... for i in range(1, 100):
... yield np.ones((16, i, 83)), np.array(i)
>>>
>>> dataset = ds.GeneratorDataset(generator1, ["data1", "data2"])
>>> dataset.set_dynamic_columns(columns={"data1": [16, None, 83], "data2": []})
>>> min_shapes, max_shapes = dataset.dynamic_min_max_shapes()
"""
if self.saved_min_shapes is None or self.saved_max_shapes is None:
self.saved_output_shapes, self.saved_min_shapes, self.saved_max_shapes = self._dynamic_output_shapes()
return self.saved_min_shapes, self.saved_max_shapes
@staticmethod
def __check_dynamic_column_name(dynamic_columns, dataset_columns):
for column in dynamic_columns:
@ -1714,73 +1661,6 @@ class Dataset:
if dynamic_columns[col][dim] is not None and dynamic_columns[col][dim] != data[col].shape[dim]:
raise RuntimeError(shape_mismatch)
def _dynamic_output_shapes(self):
"""
Get dynamic information of source data.
Returns:
lists, dynamic_shapes, min_shapes, max_shapes of source data.
"""
if not self.dynamic_setting[1]:
raise RuntimeError("dynamic_columns is not set, call set_dynamic_columns() by final Dataset Op.")
if self.saved_output_shapes is not None and self.saved_min_shapes is not None and \
self.saved_max_shapes is not None:
return self.saved_output_shapes, self.saved_min_shapes, self.saved_max_shapes
logger.warning("Calculating dynamic shape of input data, this will take a few minutes...")
# Assume data1 shape is dynamic, data2 shape is fix
dynamic_columns = self.dynamic_setting[1]
# ["data1", "data2"]
dataset_columns = self.get_col_names()
Dataset.__check_dynamic_column_name(dynamic_columns, dataset_columns)
# Shape[1] of data1 is variable
# {"data1": {(batch_size, 100, feat_len), (16, 200, 83)}, "data2": {(batch_size, feat_len)}}
column_shape_set = {col: set() for col in dataset_columns}
dataset_size_counter = 0
for data in self.create_dict_iterator(num_epochs=1, output_numpy=True):
dataset_size_counter += 1
for col in data.keys():
if col in dynamic_columns:
Dataset.__check_dynamic_column_shape(data, col, dynamic_columns)
column_shape_set[col].add(tuple(data[col].shape))
# we get dataset_size after dryrun
self.dataset_size = dataset_size_counter
min_shapes, max_shapes, dynamic_shapes = list(), list(), list()
for col, shape_set in column_shape_set.items():
if len(shape_set) > 1:
if col not in dynamic_columns:
raise RuntimeError("column [" + col + "] has dynamic shape but not set by set_dynamic_columns()" +
", shapes of [" + col + "]: " + str(list(shape_set)))
shape_npy = np.array(list(shape_set))
max_shape = shape_npy.max(axis=0)
min_shape = shape_npy.min(axis=0)
# Set min shape to 1 due to unknown shuffle
min_shape = np.where(np.equal(dynamic_columns[col], None), 1, min_shape)
# Set dynamic dim to -1 for ME
dynamic_shape = np.where(np.equal(dynamic_columns[col], None), -1, dynamic_columns[col])
max_shapes.append(max_shape.tolist())
min_shapes.append(min_shape.tolist())
dynamic_shapes.append(dynamic_shape.tolist())
else:
# Also append fix shape to keep order of column shape
fix_shape = list(list(shape_set)[0])
max_shapes.append(fix_shape)
min_shapes.append(fix_shape)
dynamic_shapes.append(fix_shape)
if col in dynamic_columns:
logger.warning("column [" + col + "] has no dynamic shape but set by set_dynamic_columns()")
# Set min shape to 1 due to unknown shuffle
min_shapes[-1] = np.where(np.equal(dynamic_columns[col], None), 1, fix_shape).tolist()
# Set dynamic dim to -1 for ME
dynamic_shapes[-1] = np.where(np.equal(dynamic_columns[col], None), -1, fix_shape).tolist()
return dynamic_shapes, min_shapes, max_shapes
def num_classes(self):
"""
Get the number of classes in a dataset.

View File

@ -20,9 +20,8 @@ import inspect as ins
import os
import re
from functools import wraps
import numpy as np
from mindspore import log
from mindspore._c_expression import typing
from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \
INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \
@ -3017,29 +3016,6 @@ def check_multi30k_dataset(method):
return new_method
def deprecated(version, substitute=None):
"""deprecated warning
Args:
version (str): version that the operation or function is deprecated.
substitute (str): the substitute name for deprecated operation or function.
"""
def decorate(func):
def wrapper(*args, **kwargs):
name = func.__name__
message = f"'{name}' is deprecated from version {version} and will be removed in a future version. "
if substitute:
message += f"Use '{substitute}' instead."
log.warning(message)
ret = func(*args, **kwargs)
return ret
return wrapper
return decorate
def check_obsminddataset(method):
"""A wrapper that wraps a parameter checker around the original Dataset(OBSMindDataset)."""

View File

@ -85,7 +85,7 @@ class _DataWrapper(nn.Cell):
dataset channel 'queue_name' and performs the forward computation.
"""
def __init__(self, network, dataset_types, dataset_shapes, queue_name, min_shapes=None, max_shapes=None):
def __init__(self, network, dataset_types, dataset_shapes, queue_name):
super(_DataWrapper, self).__init__(
auto_prefix=False, flags=network.get_flags())
# Also copy the flag in `network` construct
@ -94,11 +94,6 @@ class _DataWrapper(nn.Cell):
self.add_flags(**flags)
self.get_next = P.GetNext(
dataset_types, dataset_shapes, len(dataset_types), queue_name)
if min_shapes is not None and max_shapes is not None:
Validator.check_value_type("min_shapes", min_shapes, [list, tuple])
Validator.check_value_type("max_shapes", max_shapes, [list, tuple])
self.get_next.add_prim_attr("min_shapes", min_shapes)
self.get_next.add_prim_attr("max_shapes", max_shapes)
self.network = network
def construct(self):
@ -106,11 +101,10 @@ class _DataWrapper(nn.Cell):
return self.network(*outputs)
def _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name,
min_shapes=None, max_shapes=None):
def _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name):
if not isinstance(network, _DataWrapper):
network = _DataWrapper(
network, dataset_types, dataset_shapes, queue_name, min_shapes, max_shapes)
network, dataset_types, dataset_shapes, queue_name)
return network
@ -126,18 +120,11 @@ def _generate_network_with_dataset(network, dataset_helper, queue_name):
Generate new network with network and dataset info.
"""
dataset_types, dataset_shapes = dataset_helper.types_shapes()
if not _has_dynamic_shape(dataset_shapes):
(min_shapes, max_shapes) = (None, None)
else:
(min_shapes, max_shapes) = dataset_helper.dynamic_min_max_shapes()
if network.get_inputs() and None not in network.get_inputs():
_check_inputs(network.get_inputs(), dataset_shapes, dataset_types)
min_shapes, max_shapes = None, None
elif context.get_context("mode") == context.PYNATIVE_MODE:
dataset_shapes = tuple([(-2,)] * len(dataset_shapes))
min_shapes, max_shapes = None, None
network = _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types,
queue_name, min_shapes, max_shapes)
network = _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name)
return network
@ -418,28 +405,6 @@ class DatasetHelper:
"""
return self.iter.get_data_info()
def dynamic_min_max_shapes(self):
"""
Return the minimum and maximum data length of dynamic source dataset.
Examples:
>>> import mindspore as ms
>>> import numpy as np
>>>
>>> # Define a dataset pipeline
>>> def generator():
... for i in range(5):
... yield (np.ones((32, i)),)
>>>
>>> train_dataset = ms.dataset.GeneratorDataset(generator, ["data"])
>>> # config dynamic shape
>>> train_dataset.set_dynamic_columns(columns={"data": [32, None]})
>>> dataset_helper = ms.DatasetHelper(train_dataset, dataset_sink_mode=True)
>>>
>>> min_shapes, max_shapes = dataset_helper.dynamic_min_max_shapes()
"""
return self.iter.dynamic_min_max_shapes()
class _DatasetIter:
"""Base iter for dataset helper"""
@ -476,7 +441,6 @@ class _DatasetIter:
self.release = dataset.__transfer_dataset__.release
self.continue_send = dataset.__transfer_dataset__.continue_send
self.get_data_info = dataset.__transfer_dataset__.get_data_info
self.dynamic_min_max_shapes = dataset.dynamic_min_max_shapes
if hasattr(dataset.__transfer_dataset__, "_reset"):
self._reset = dataset.__transfer_dataset__._reset # pylint: disable=W0212

View File

@ -74,9 +74,6 @@ class MindData:
def get_data_info(self):
pass
def dynamic_min_max_shapes(self):
pass
def __len__(self):
return self._size

View File

@ -61,7 +61,9 @@ def run_async_dump(test_name):
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
network = Net()
dataset = ds.GeneratorDataset(dataset_generator, ['data1', 'data2'])
dataset.set_dynamic_columns(columns={'data1': [32, None], 'data2': [32, None]})
t0 = Tensor(dtype=mindspore.float32, shape=[32, None])
t1 = Tensor(dtype=mindspore.float32, shape=[32, None])
network.set_inputs(t0, t1)
model = Model(network)
with tempfile.TemporaryDirectory(dir='/tmp') as tmp_dir:
dump_path = os.path.join(tmp_dir, 'async_dump')
@ -99,7 +101,9 @@ def run_e2e_dump():
return
network = Net()
dataset = ds.GeneratorDataset(dataset_generator, ['data1', 'data2'])
dataset.set_dynamic_columns(columns={'data1': [32, None], 'data2': [32, None]})
t0 = Tensor(dtype=mindspore.float32, shape=[32, None])
t1 = Tensor(dtype=mindspore.float32, shape=[32, None])
network.set_inputs(t0, t1)
model = Model(network)
with tempfile.TemporaryDirectory(dir='/tmp') as tmp_dir:
dump_path = os.path.join(tmp_dir, 'e2e_dump')

View File

@ -17,13 +17,15 @@ import glob
import tempfile
import numpy as np
import pytest
import mindspore
import mindspore.context as context
import mindspore.dataset as ds
import mindspore.nn as nn
from mindspore import Model
from mindspore import Profiler
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.dataset as ds
from mindspore import Profiler
from mindspore import Model
from tests.security_utils import security_off_wrap
@ -132,7 +134,9 @@ def test_shape():
network = NetWork()
profiler = Profiler(output_path=tmpdir)
dataset = ds.GeneratorDataset(dataset_generator, ["data1", "data2"])
dataset.set_dynamic_columns(columns={"data1": [32, None], "data2": [32, None]})
t0 = Tensor(dtype=mindspore.float32, shape=[32, None])
t1 = Tensor(dtype=mindspore.float32, shape=[32, None])
network.set_inputs(t0, t1)
model = Model(network)
model.train(1, dataset, dataset_sink_mode=True)
profiler.analyse()

View File

@ -23,28 +23,21 @@ def generator0():
yield (np.ones((32, i)), np.zeros((16, i, i, 3)), np.ones((i)))
def test_get_dynamic_min_max_shapes_0():
def test_output_shapes_0():
"""
Feature: dynamic_min_max_shapes
Description: Test dynamic_min_max_shapes with dynamic shape columns
Feature: Test output_shapes
Description: Test output_shapes with data of generator0
Expectation: The dataset is processed as expected
"""
logger.info("Test dynamic_min_max_shapes with dynamic shape columns.")
logger.info("Test output_shapes with data of generator0.")
dataset = ds.GeneratorDataset(generator0, ["data1", "data2", "data3"])
# new api
estimate_dynamic_shapes = dataset.output_shapes(estimate=True)
# old api
dataset.set_dynamic_columns(columns={"data1": [32, None], "data2": [16, None, None, 3], "data3": [None]})
min_shapes, max_shapes = dataset.dynamic_min_max_shapes()
dynamic_shapes = dataset.output_shapes()
# check result
np.testing.assert_array_equal(min_shapes, [[32, 1], [16, 1, 1, 3], [1]])
np.testing.assert_array_equal(max_shapes, [[32, 69], [16, 69, 69, 3], [69]])
np.testing.assert_array_equal(dynamic_shapes, [[32, -1], [16, -1, -1, 3], [-1]])
np.testing.assert_array_equal(dynamic_shapes, [[32, 50], [16, 50, 50, 3], [50]])
np.testing.assert_array_equal(estimate_dynamic_shapes, [[32, None], [16, None, None, 3], [None]])
@ -53,66 +46,37 @@ def generator1():
yield (np.ones((16, i, 83)), np.array((i)))
def test_get_dynamic_min_max_shapes_1():
def test_output_shapes_1():
"""
Feature: dynamic_min_max_shapes
Description: Test dynamic_min_max_shapes with dynamic shape column and fix shape column
Feature: Test output_shapes
Description: Test output_shapes with data of generator1
Expectation: The dataset is processed as expected
"""
logger.info("Test dynamic_min_max_shapes with dynamic shape column and fix shape column.")
logger.info("Test output_shapes with data of generator1.")
dataset = ds.GeneratorDataset(generator1, ["data1", "data2"])
# new api
estimate_dynamic_shapes = dataset.output_shapes(estimate=True)
# old api
dataset.set_dynamic_columns(columns={"data1": [16, None, 83], "data2": []})
dynamic_shapes = dataset.output_shapes()
min_shapes, max_shapes = dataset.dynamic_min_max_shapes()
# check result
# raise a warning to tell user "data2" is not dynamic
np.testing.assert_array_equal(min_shapes, [[16, 1, 83], []])
np.testing.assert_array_equal(max_shapes, [[16, 99, 83], []])
np.testing.assert_array_equal(dynamic_shapes, [[16, -1, 83], []])
np.testing.assert_array_equal(dynamic_shapes, [[16, 1, 83], []])
np.testing.assert_array_equal(estimate_dynamic_shapes, [[16, None, 83], []])
def test_get_dynamic_min_max_shapes_2():
"""
Feature: dynamic_min_max_shapes
Description: Test dynamic_min_max_shapes with setting all columns to dynamic
Expectation: The dataset is processed as expected
"""
logger.info("Test dynamic_min_max_shapes with setting all columns to dynamic.")
dataset = ds.GeneratorDataset(generator1, ["data1", "data2"])
# config all dims have dynamic shape
dataset.set_dynamic_columns(columns={"data1": [None, None, None]})
dynamic_shapes = dataset.output_shapes()
min_shapes, max_shapes = dataset.dynamic_min_max_shapes()
# check result
# Although shape[0] of data1 is fix in given data, user think it is dynamic, so shape[0] is dynamic
np.testing.assert_array_equal(min_shapes, [[1, 1, 1], []])
np.testing.assert_array_equal(max_shapes, [[16, 99, 83], []])
np.testing.assert_array_equal(dynamic_shapes, [[-1, -1, -1], []])
def generator2():
for i in range(80, 100):
yield (np.ones((16, i, 83)), np.ones((5, 5)))
def test_get_dynamic_min_max_shapes_3():
def test_output_shapes_2():
"""
Feature: dynamic_min_max_shapes
Description: Test dynamic_min_max_shapes with dynamic shape columns
Feature: Test output_shapes
Description: Test output_shapes with data of generator2
Expectation: The dataset is processed as expected
"""
logger.info("Test dynamic_min_max_shapes only config dynamic column.")
logger.info("Test output_shapes with data of generator2.")
dataset = ds.GeneratorDataset(generator2, ["data1", "data2"])
@ -120,51 +84,21 @@ def test_get_dynamic_min_max_shapes_3():
estimate_dynamic_shapes = dataset.output_shapes(estimate=True)
# old api
dataset.set_dynamic_columns(columns={"data1": [16, None, 83]})
dynamic_shapes = dataset.output_shapes()
min_shapes, max_shapes = dataset.dynamic_min_max_shapes()
# check result
# column with fixed shape will also be appended to shapes list
np.testing.assert_array_equal(min_shapes, [[16, 1, 83], [5, 5]])
np.testing.assert_array_equal(max_shapes, [[16, 99, 83], [5, 5]])
np.testing.assert_array_equal(dynamic_shapes, [[16, -1, 83], [5, 5]])
np.testing.assert_array_equal(dynamic_shapes, [[16, 80, 83], [5, 5]])
np.testing.assert_array_equal(estimate_dynamic_shapes, [[16, None, 83], [5, 5]])
def test_get_dynamic_min_max_shapes_4():
def test_output_shapes_3():
"""
Feature: dynamic_min_max_shapes
Description: Test dynamic_min_max_shapes with dynamic setting for column with fixed shape
Feature: Test output_shapes
Description: Test output_shapes with NumpySlicesDataset
Expectation: The dataset is processed as expected
"""
logger.info("Test dynamic_min_max_shapes with dynamic setting for column with fixed shape.")
dataset = ds.GeneratorDataset(generator2, ["data1", "data2"])
# new api
estimate_dynamic_shapes = dataset.output_shapes(estimate=True)
# old api
dataset.set_dynamic_columns(columns={"data1": [16, None, 83], "data2": [None, 5]})
dynamic_shapes = dataset.output_shapes()
min_shapes, max_shapes = dataset.dynamic_min_max_shapes()
# check result
# column with fixed shape will also be appended to shapes list
np.testing.assert_array_equal(min_shapes, [[16, 1, 83], [1, 5]])
np.testing.assert_array_equal(max_shapes, [[16, 99, 83], [5, 5]])
np.testing.assert_array_equal(dynamic_shapes, [[16, -1, 83], [-1, 5]])
np.testing.assert_array_equal(estimate_dynamic_shapes, [[16, None, 83], [5, 5]])
def test_get_dynamic_min_max_shapes_5():
"""
Feature: dynamic_min_max_shapes
Description: Test dynamic_min_max_shapes with NumpySlicesDataset
Expectation: The dataset is processed as expected
"""
logger.info("Test dynamic_min_max_shapes with NumpySlicesDataset.")
logger.info("Test output_shapes with NumpySlicesDataset.")
np_data = [
[[1, 2], [3, 4]],
@ -179,64 +113,14 @@ def test_get_dynamic_min_max_shapes_5():
estimate_dynamic_shapes = dataset.output_shapes(estimate=True)
# old api
dataset.set_dynamic_columns(columns={"col1": [2, None]})
dynamic_shapes = dataset.output_shapes()
min_shapes, max_shapes = dataset.dynamic_min_max_shapes()
# check result
# column with fixed shape will also be appended to shapes list
np.testing.assert_array_equal(min_shapes, [[2, 1]])
np.testing.assert_array_equal(max_shapes, [[2, 2]])
np.testing.assert_array_equal(dynamic_shapes, [[2, -1]])
np.testing.assert_array_equal(dynamic_shapes, [[2, 2]])
np.testing.assert_array_equal(estimate_dynamic_shapes, [[2, 2]])
def test_get_dynamic_min_max_shapes_6():
"""
Feature: dynamic_min_max_shapes
Description: Test dynamic_min_max_shapes with unexpected column setting
Expectation: The dataset is processed as expected
"""
logger.info("Test dynamic_min_max_shapes with unexpected column setting.")
dataset = ds.GeneratorDataset(generator1, ["data1", "data2"])
with pytest.raises(TypeError) as info:
# dynamic column is not in dict
dataset.set_dynamic_columns(columns=list())
assert "Pass a dict to set dynamic shape" in str(info.value)
with pytest.raises(RuntimeError) as info:
# dynamic column is not set
dataset.set_dynamic_columns(columns=dict())
dataset.dynamic_min_max_shapes()
assert "dynamic_columns is not set, call set_dynamic_columns()" in str(info.value)
with pytest.raises(RuntimeError) as info:
# dynamic column is not set
dataset.set_dynamic_columns(columns={"data2": []})
dataset.dynamic_min_max_shapes()
assert "column [data1] has dynamic shape but not set by set_dynamic_columns()" in str(info.value)
with pytest.raises(RuntimeError) as info:
# column does not exist
dataset.set_dynamic_columns(columns={"data3": [16, None, 83]})
dataset.dynamic_min_max_shapes()
assert "dynamic column [data3] does not match any column in dataset" in str(info.value)
with pytest.raises(RuntimeError) as info:
# unexpected column shape
dataset.set_dynamic_columns(columns={"data1": [16, 83, None]})
dataset.dynamic_min_max_shapes()
assert "shape [16, 83, None] does not match dataset column [data1] with shape [16, 1, 83]" in str(info.value)
with pytest.raises(RuntimeError) as info:
# unexpected column shape
dataset.set_dynamic_columns(columns={"data1": [16, None]})
dataset.dynamic_min_max_shapes()
assert "shape [16, None] does not match dataset column [data1] with shape [16, 1, 83]" in str(info.value)
class Generator3:
def __init__(self):
self.data = [np.array([[1], [2]]), np.array([1, 2])]
@ -271,8 +155,4 @@ if __name__ == "__main__":
test_get_dynamic_min_max_shapes_1()
test_get_dynamic_min_max_shapes_2()
test_get_dynamic_min_max_shapes_3()
test_get_dynamic_min_max_shapes_4()
test_get_dynamic_min_max_shapes_5()
test_get_dynamic_min_max_shapes_6()
test_output_shapes_exception()