!6113 Unify minddata seed to set_seed

Merge pull request !6113 from xiefangqi/md_unify_seed
This commit is contained in:
mindspore-ci-bot 2020-09-14 16:20:04 +08:00 committed by Gitee
commit 10aec24510
4 changed files with 5 additions and 7 deletions

View File

@ -14,6 +14,7 @@
# ============================================================================
"""Provide random seed api."""
import numpy as np
import mindspore.dataset as de
# set global RNG seed
_GLOBAL_SEED = None
@ -43,6 +44,7 @@ def set_seed(seed):
if seed < 0:
raise ValueError("The seed must be greater or equal to 0.")
np.random.seed(seed)
de.config.set_seed(seed)
global _GLOBAL_SEED
_GLOBAL_SEED = seed

View File

@ -23,7 +23,6 @@ from functools import wraps
import numpy as np
from mindspore._c_expression import typing
from mindspore.dataset.callback import DSCallback
from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \
INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \
validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_num_parallel_workers, \
@ -31,8 +30,6 @@ from ..core.validator_helpers import parse_user_args, type_check, type_check_lis
from . import datasets
from . import samplers
# from . import cache_client
from .. import callback
def check_imagefolderdataset(method):
@ -566,6 +563,7 @@ def check_map(method):
@wraps(method)
def new_method(self, *args, **kwargs):
from mindspore.dataset.callback import DSCallback
[_, input_columns, output_columns, column_order, num_parallel_workers, python_multiprocessing, cache,
callbacks], _ = \
parse_user_args(method, *args, **kwargs)
@ -581,9 +579,9 @@ def check_map(method):
if callbacks is not None:
if isinstance(callbacks, (list, tuple)):
type_check_list(callbacks, (callback.DSCallback,), "callbacks")
type_check_list(callbacks, (DSCallback,), "callbacks")
else:
type_check(callbacks, (callback.DSCallback,), "callbacks")
type_check(callbacks, (DSCallback,), "callbacks")
for param_name, param in zip(nreq_param_columns, [input_columns, output_columns, column_order]):
if param is not None:

View File

@ -26,7 +26,6 @@ from src.config import cifar_cfg, imagenet_cfg
def create_dataset_cifar10(data_home, repeat_num=1, training=True):
"""Data operations."""
ds.config.set_seed(1)
data_dir = os.path.join(data_home, "cifar-10-batches-bin")
if not training:
data_dir = os.path.join(data_home, "cifar-10-verify-bin")

View File

@ -28,7 +28,6 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True
def vgg_create_dataset(data_home, image_size, batch_size, rank_id=0, rank_size=1, repeat_num=1, training=True):
"""Data operations."""
de.config.set_seed(1)
data_dir = os.path.join(data_home, "cifar-10-batches-bin")
if not training:
data_dir = os.path.join(data_home, "cifar-10-verify-bin")