diff --git a/mindspore/dataset/core/config.py b/mindspore/dataset/core/config.py index ded88839c48..8adb1505c73 100644 --- a/mindspore/dataset/core/config.py +++ b/mindspore/dataset/core/config.py @@ -32,18 +32,6 @@ INT32_MAX = 2147483647 UINT32_MAX = 4294967295 _config = cde.GlobalContext.config_manager() -_dynamic_columns = dict() - - -def set_dynamic_columns(columns=None): - global _dynamic_columns - if not isinstance(columns, dict): - raise TypeError("Pass a dict to set dynamic shape, example: {\"data1\": [16, None, 256]}") - _dynamic_columns = columns - - -def get_dynamic_columns(): - return _dynamic_columns def _init_device_info(): diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 119be4eb45f..961eba74249 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -60,7 +60,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset, check_paddeddataset, \ check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \ - get_prefetch_size, get_dynamic_columns + get_prefetch_size from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist from ..core.validator_helpers import replace_none from ..core.py_util_helpers import ExceptionHandler @@ -209,15 +209,15 @@ class Dataset: self._input_indexs = () self.saved_output_types = None self.saved_output_shapes = None + self.dynamic_setting = [False, None] + self.saved_min_shapes = None + self.saved_max_shapes = None self._col_names = None self.dataset_size = None self._batch_size = None self._num_classes = None self._repeat_count = None self._class_indexing = None - self.min_shapes = None - self.max_shapes = None - self.dynamic_shapes = None self._sync = False def create_ir_tree(self): @@ -1531,8 +1531,9 @@ class Dataset: if self.saved_output_shapes is None: runtime_getter = self._init_tree_getters() self.saved_output_shapes = runtime_getter[0].GetOutputShapes() - self.saved_output_types = runtime_getter[0].GetOutputTypes() self.close_pool() + if self.dynamic_setting[0]: + self.saved_output_shapes, self.saved_min_shapes, self.saved_max_shapes = self._dynamic_output_shapes() return self.saved_output_shapes def output_types(self): @@ -1544,7 +1545,6 @@ class Dataset: """ if self.saved_output_types is None: runtime_getter = self._init_tree_getters() - self.saved_output_shapes = runtime_getter[0].GetOutputShapes() self.saved_output_types = runtime_getter[0].GetOutputTypes() self.close_pool() return self.saved_output_types @@ -1562,24 +1562,35 @@ class Dataset: self.close_pool() return self.dataset_size - def get_dynamic_min_max_shape(self): + def set_dynamic_columns(self, columns=None): + if not isinstance(columns, dict): + raise TypeError("Pass a dict to set dynamic shape, example: {\"data1\": [16, None, 256]}") + self.dynamic_setting[0] = True + self.dynamic_setting[1] = columns + + def dynamic_min_max_shapes(self): + if self.saved_min_shapes is None or self.saved_max_shapes is None: + self.saved_output_shapes, self.saved_min_shapes, self.saved_max_shapes = self._dynamic_output_shapes() + return self.saved_min_shapes, self.saved_max_shapes + + def _dynamic_output_shapes(self): """ Get dynamic information of source data. Returns: - lists, min_shapes, max_shapes, dynamic_shapes of source data. + lists, dynamic_shapes, min_shapes, max_shapes of source data. """ - # Assume data1 shape is dynamic, data2 shape is fix - # {"data1": [batch_size, None, feat_len], "data2": [batch_size, feat_len]} - dynamic_columns = get_dynamic_columns() - if not dynamic_columns: + if not self.dynamic_setting[1]: raise RuntimeError("dynamic_columns is not set, call set_dynamic_columns() first.") - if self.min_shapes is not None and self.max_shapes is not None and self.dynamic_shapes is not None: - return self.min_shapes, self.max_shapes, self.dynamic_shapes + if self.saved_output_shapes is not None and self.saved_min_shapes is not None and \ + self.saved_max_shapes is not None: + return self.saved_output_shapes, self.saved_min_shapes, self.saved_max_shapes logger.warning("Calculating dynamic shape of input data, this will take a few minutes...") - + # Assume data1 shape is dynamic, data2 shape is fix + # {"data1": [batch_size, None, feat_len], "data2": [batch_size, feat_len]} + dynamic_columns = self.dynamic_setting[1] # ["data1", "data2"] dataset_columns = self.get_col_names() for column in dynamic_columns: @@ -1633,10 +1644,7 @@ class Dataset: max_shapes.append(fix_shape) min_shapes.append(fix_shape) dynamic_shapes.append(fix_shape) - self.min_shapes = min_shapes - self.max_shapes = max_shapes - self.dynamic_shapes = dynamic_shapes - return self.min_shapes, self.max_shapes, self.dynamic_shapes + return dynamic_shapes, min_shapes, max_shapes def num_classes(self): """ diff --git a/mindspore/train/dataset_helper.py b/mindspore/train/dataset_helper.py index 61d0be53f32..39df6c954b4 100644 --- a/mindspore/train/dataset_helper.py +++ b/mindspore/train/dataset_helper.py @@ -254,8 +254,9 @@ class DatasetHelper: """Get the types and shape of current batch.""" return self.iter.get_data_info() - def get_dynamic_min_max_shape(self): - return self.iter.get_dynamic_min_max_shape() + def dynamic_min_max_shapes(self): + """Get shape range(min shape, max shape) of dynamic data.""" + return self.iter.dynamic_min_max_shapes() class _DatasetIter: """Base iter for dataset helper""" @@ -285,7 +286,7 @@ class _DatasetIter: self.release = dataset.__transfer_dataset__.release self.continue_send = dataset.__transfer_dataset__.continue_send self.get_data_info = dataset.__transfer_dataset__.get_data_info - self.get_dynamic_min_max_shape = dataset.__transfer_dataset__.get_dynamic_min_max_shape + self.dynamic_min_max_shapes = dataset.__transfer_dataset__.dynamic_min_max_shapes self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset) def __iter__(self): diff --git a/tests/dataset_mock.py b/tests/dataset_mock.py index 4159444a7cf..fca98cf72a8 100644 --- a/tests/dataset_mock.py +++ b/tests/dataset_mock.py @@ -73,7 +73,7 @@ class MindData: def get_data_info(self): pass - def get_dynamic_min_max_shape(self): + def dynamic_min_max_shapes(self): pass def __len__(self): diff --git a/tests/ut/python/dataset/test_datasets_get_dynamic_shape.py b/tests/ut/python/dataset/test_datasets_get_dynamic_shape.py index 9d714239ae3..ebfae9bde78 100644 --- a/tests/ut/python/dataset/test_datasets_get_dynamic_shape.py +++ b/tests/ut/python/dataset/test_datasets_get_dynamic_shape.py @@ -23,16 +23,17 @@ def generator0(): yield (np.ones((32, i)), np.zeros((16, i, i, 3)), np.ones((i))) -def test_get_dynamic_min_max_shape_0(): - logger.info("Test get_dynamic_min_max_shape with dynamic shape columns") +def test_get_dynamic_min_max_shapes_0(): + logger.info("Test dynamic_min_max_shapes with dynamic shape columns") dataset = ds.GeneratorDataset(generator0, ["data1", "data2", "data3"]) # config dynamic shape - ds.config.set_dynamic_columns(columns={"data1": [32, None], "data2": [16, None, None, 3], "data3": [None]}) + dataset.set_dynamic_columns(columns={"data1": [32, None], "data2": [16, None, None, 3], "data3": [None]}) # get dynamic information - min_shapes, max_shapes, dynamic_shapes = dataset.get_dynamic_min_max_shape() + min_shapes, max_shapes = dataset.dynamic_min_max_shapes() + dynamic_shapes = dataset.output_shapes() # check result np.testing.assert_array_equal(min_shapes, [[32, 1], [16, 1, 1, 3], [1]]) @@ -45,16 +46,17 @@ def generator1(): yield (np.ones((16, i, 83)), np.array((i))) -def test_get_dynamic_min_max_shape_1(): - logger.info("Test get_dynamic_min_max_shape with dynamic shape column and fix shape column") +def test_get_dynamic_min_max_shapes_1(): + logger.info("Test dynamic_min_max_shapes with dynamic shape column and fix shape column") dataset = ds.GeneratorDataset(generator1, ["data1", "data2"]) # config dynamic shape - ds.config.set_dynamic_columns(columns={"data1": [16, None, 83], "data2": []}) + dataset.set_dynamic_columns(columns={"data1": [16, None, 83], "data2": []}) # get dynamic information - min_shapes, max_shapes, dynamic_shapes = dataset.get_dynamic_min_max_shape() + dynamic_shapes = dataset.output_shapes() + min_shapes, max_shapes = dataset.dynamic_min_max_shapes() # check result # raise a warning to tell user "data2" is not dynamic @@ -63,14 +65,15 @@ def test_get_dynamic_min_max_shape_1(): np.testing.assert_array_equal(dynamic_shapes, [[16, -1, 83], []]) -def test_get_dynamic_min_max_shape_2(): - logger.info("Test get_dynamic_min_max_shape with all dynamic config") +def test_get_dynamic_min_max_shapes_2(): + logger.info("Test dynamic_min_max_shapes with all dynamic config") dataset = ds.GeneratorDataset(generator1, ["data1", "data2"]) # config all dims have dynamic shape - ds.config.set_dynamic_columns(columns={"data1": [None, None, None]}) - min_shapes, max_shapes, dynamic_shapes = dataset.get_dynamic_min_max_shape() + dataset.set_dynamic_columns(columns={"data1": [None, None, None]}) + dynamic_shapes = dataset.output_shapes() + min_shapes, max_shapes = dataset.dynamic_min_max_shapes() # check result # Although shape[0] of data1 is fix in given data, user think it is dynamic, so shape[0] is dynamic @@ -84,16 +87,17 @@ def generator2(): yield (np.ones((16, i, 83)), np.ones((5, 5))) -def test_get_dynamic_min_max_shape_3(): - logger.info("Test get_dynamic_min_max_shape with only config dynamic column") +def test_get_dynamic_min_max_shapes_3(): + logger.info("Test dynamic_min_max_shapes with only config dynamic column") dataset = ds.GeneratorDataset(generator2, ["data1", "data2"]) # only dynamic shape is required to config - ds.config.set_dynamic_columns(columns={"data1": [16, None, 83]}) + dataset.set_dynamic_columns(columns={"data1": [16, None, 83]}) # get dynamic information - min_shapes, max_shapes, dynamic_shapes = dataset.get_dynamic_min_max_shape() + dynamic_shapes = dataset.output_shapes() + min_shapes, max_shapes = dataset.dynamic_min_max_shapes() # check result # column with fix shape will be also appended to shapes list @@ -102,51 +106,51 @@ def test_get_dynamic_min_max_shape_3(): np.testing.assert_array_equal(dynamic_shapes, [[16, -1, 83], [5, 5]]) -def test_get_dynamic_min_max_shape_4(): - logger.info("Test get_dynamic_min_max_shape with unexpected column setting") +def test_get_dynamic_min_max_shapes_4(): + logger.info("Test dynamic_min_max_shapes with unexpected column setting") dataset = ds.GeneratorDataset(generator1, ["data1", "data2"]) with pytest.raises(TypeError) as info: # dynamic column is not in dict - ds.config.set_dynamic_columns(columns=list()) + dataset.set_dynamic_columns(columns=list()) assert "Pass a dict to set dynamic shape" in str(info.value) with pytest.raises(RuntimeError) as info: # dynamic column is not set - ds.config.set_dynamic_columns(columns=dict()) - dataset.get_dynamic_min_max_shape() + dataset.set_dynamic_columns(columns=dict()) + dataset.dynamic_min_max_shapes() assert "dynamic_columns is not set, call set_dynamic_columns() first" in str(info.value) with pytest.raises(RuntimeError) as info: # dynamic column is not set - ds.config.set_dynamic_columns(columns={"data2": []}) - dataset.get_dynamic_min_max_shape() + dataset.set_dynamic_columns(columns={"data2": []}) + dataset.dynamic_min_max_shapes() assert "column [data1] has dynamic shape but not set by set_dynamic_columns()" in str(info.value) with pytest.raises(RuntimeError) as info: # column does not exist - ds.config.set_dynamic_columns(columns={"data3": [16, None, 83]}) - dataset.get_dynamic_min_max_shape() + dataset.set_dynamic_columns(columns={"data3": [16, None, 83]}) + dataset.dynamic_min_max_shapes() assert "dynamic column [data3] does not match any column in dataset" in str(info.value) with pytest.raises(RuntimeError) as info: # unexpected column shape - ds.config.set_dynamic_columns(columns={"data1": [16, 83, None]}) - dataset.get_dynamic_min_max_shape() + dataset.set_dynamic_columns(columns={"data1": [16, 83, None]}) + dataset.dynamic_min_max_shapes() assert "shape [16, 83, None] does not match dataset column [data1] with shape [16, 1, 83]" in str(info.value) with pytest.raises(RuntimeError) as info: # unexpected column shape - ds.config.set_dynamic_columns(columns={"data1": [16, None]}) - dataset.get_dynamic_min_max_shape() + dataset.set_dynamic_columns(columns={"data1": [16, None]}) + dataset.dynamic_min_max_shapes() assert "shape [16, None] does not match dataset column [data1] with shape [16, 1, 83]" in str(info.value) if __name__ == "__main__": - test_get_dynamic_min_max_shape_0() - test_get_dynamic_min_max_shape_1() - test_get_dynamic_min_max_shape_2() - test_get_dynamic_min_max_shape_3() - test_get_dynamic_min_max_shape_4() + test_get_dynamic_min_max_shapes_0() + test_get_dynamic_min_max_shapes_1() + test_get_dynamic_min_max_shapes_2() + test_get_dynamic_min_max_shapes_3() + test_get_dynamic_min_max_shapes_4() \ No newline at end of file