diff --git a/mindspore/common/seed.py b/mindspore/common/seed.py index cb32d2ce9a5..88c15fc3031 100644 --- a/mindspore/common/seed.py +++ b/mindspore/common/seed.py @@ -14,6 +14,7 @@ # ============================================================================ """Provide random seed api.""" import numpy as np +import mindspore.dataset as de # set global RNG seed _GLOBAL_SEED = None @@ -43,6 +44,7 @@ def set_seed(seed): if seed < 0: raise ValueError("The seed must be greater or equal to 0.") np.random.seed(seed) + de.config.set_seed(seed) global _GLOBAL_SEED _GLOBAL_SEED = seed diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index 18fd5f64eb1..b7c10608690 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -23,7 +23,6 @@ from functools import wraps import numpy as np from mindspore._c_expression import typing -from mindspore.dataset.callback import DSCallback from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \ INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \ validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_num_parallel_workers, \ @@ -31,8 +30,6 @@ from ..core.validator_helpers import parse_user_args, type_check, type_check_lis from . import datasets from . import samplers -# from . import cache_client -from .. import callback def check_imagefolderdataset(method): @@ -566,6 +563,7 @@ def check_map(method): @wraps(method) def new_method(self, *args, **kwargs): + from mindspore.dataset.callback import DSCallback [_, input_columns, output_columns, column_order, num_parallel_workers, python_multiprocessing, cache, callbacks], _ = \ parse_user_args(method, *args, **kwargs) @@ -581,9 +579,9 @@ def check_map(method): if callbacks is not None: if isinstance(callbacks, (list, tuple)): - type_check_list(callbacks, (callback.DSCallback,), "callbacks") + type_check_list(callbacks, (DSCallback,), "callbacks") else: - type_check(callbacks, (callback.DSCallback,), "callbacks") + type_check(callbacks, (DSCallback,), "callbacks") for param_name, param in zip(nreq_param_columns, [input_columns, output_columns, column_order]): if param is not None: diff --git a/model_zoo/official/cv/googlenet/src/dataset.py b/model_zoo/official/cv/googlenet/src/dataset.py index e8872303a97..494ba44f748 100644 --- a/model_zoo/official/cv/googlenet/src/dataset.py +++ b/model_zoo/official/cv/googlenet/src/dataset.py @@ -26,7 +26,6 @@ from src.config import cifar_cfg, imagenet_cfg def create_dataset_cifar10(data_home, repeat_num=1, training=True): """Data operations.""" - ds.config.set_seed(1) data_dir = os.path.join(data_home, "cifar-10-batches-bin") if not training: data_dir = os.path.join(data_home, "cifar-10-verify-bin") diff --git a/model_zoo/official/cv/vgg16/src/dataset.py b/model_zoo/official/cv/vgg16/src/dataset.py index e87947e4bd0..1361eb08f2f 100644 --- a/model_zoo/official/cv/vgg16/src/dataset.py +++ b/model_zoo/official/cv/vgg16/src/dataset.py @@ -28,7 +28,6 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True def vgg_create_dataset(data_home, image_size, batch_size, rank_id=0, rank_size=1, repeat_num=1, training=True): """Data operations.""" - de.config.set_seed(1) data_dir = os.path.join(data_home, "cifar-10-batches-bin") if not training: data_dir = os.path.join(data_home, "cifar-10-verify-bin")