From 602a8e52a018f1b17bd95367f07928bf97e615f3 Mon Sep 17 00:00:00 2001 From: liyong Date: Tue, 1 Dec 2020 15:42:57 +0800 Subject: [PATCH] fix bug in dataset helper --- mindspore/train/dataset_helper.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mindspore/train/dataset_helper.py b/mindspore/train/dataset_helper.py index 26f09cc733c..78c8cc14926 100644 --- a/mindspore/train/dataset_helper.py +++ b/mindspore/train/dataset_helper.py @@ -106,6 +106,9 @@ def connect_network_with_dataset(network, dataset_helper): if hasattr(dataset, '__network_manage__') and key in dataset.__network_manage__: network = dataset.__network_manage__[key] else: + if _need_to_full(): + device_num = _get_device_num() + dataset_shapes = _to_full_shapes(dataset_shapes, device_num) network = _DataWrapper(network, dataset_types, dataset_shapes, dataset.__transfer_dataset__.queue_name) dataset.__network_manage__ = dataset.__network_manage__ if hasattr( dataset, '__network_manage__') else dict()