forked from mindspore-Ecosystem/mindspore
!6113 Unify minddata seed to set_seed
Merge pull request !6113 from xiefangqi/md_unify_seed
commit 10aec24510
@@ -14,6 +14,7 @@
 # ============================================================================
 """Provide random seed api."""
 import numpy as np
+import mindspore.dataset as de

 # set global RNG seed
 _GLOBAL_SEED = None
@@ -43,6 +44,7 @@ def set_seed(seed):
     if seed < 0:
         raise ValueError("The seed must be greater or equal to 0.")
     np.random.seed(seed)
+    de.config.set_seed(seed)
     global _GLOBAL_SEED
     _GLOBAL_SEED = seed
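After this change, one call to the framework-level seed API covers both the NumPy RNG and the minddata pipeline seed. A minimal usage sketch, assuming set_seed is exported at the mindspore package level and using dataset.config.get_seed() only to illustrate the effect:

    import mindspore as ms
    import mindspore.dataset as ds

    ms.set_seed(42)              # now also calls ds.config.set_seed(42) internally
    print(ds.config.get_seed())  # expected to print 42 after the unification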
@@ -23,7 +23,6 @@ from functools import wraps

 import numpy as np
 from mindspore._c_expression import typing
-from mindspore.dataset.callback import DSCallback
 from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \
     INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \
     validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_num_parallel_workers, \
@@ -31,8 +30,6 @@ from ..core.validator_helpers import parse_user_args, type_check, type_check_list,

 from . import datasets
 from . import samplers
 # from . import cache_client
-from .. import callback
-

 def check_imagefolderdataset(method):
@@ -566,6 +563,7 @@ def check_map(method):

     @wraps(method)
     def new_method(self, *args, **kwargs):
+        from mindspore.dataset.callback import DSCallback
         [_, input_columns, output_columns, column_order, num_parallel_workers, python_multiprocessing, cache,
          callbacks], _ = \
             parse_user_args(method, *args, **kwargs)
@@ -581,9 +579,9 @@ def check_map(method):

         if callbacks is not None:
             if isinstance(callbacks, (list, tuple)):
-                type_check_list(callbacks, (callback.DSCallback,), "callbacks")
+                type_check_list(callbacks, (DSCallback,), "callbacks")
             else:
-                type_check(callbacks, (callback.DSCallback,), "callbacks")
+                type_check(callbacks, (DSCallback,), "callbacks")

         for param_name, param in zip(nreq_param_columns, [input_columns, output_columns, column_order]):
             if param is not None:
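Moving the DSCallback import inside new_method defers it to call time, the usual way to break the import cycle created once the random seed module imports mindspore.dataset at module level. A self-contained sketch of the same deferred-import pattern; validate_callbacks below is an illustrative helper, not a MindSpore API:

    def validate_callbacks(callbacks):
        # Defer the import to call time so this module can be loaded
        # before mindspore.dataset.callback has finished importing.
        from mindspore.dataset.callback import DSCallback

        if isinstance(callbacks, (list, tuple)):
            if not all(isinstance(cb, DSCallback) for cb in callbacks):
                raise TypeError("callbacks must contain only DSCallback instances")
        elif not isinstance(callbacks, DSCallback):
            raise TypeError("callbacks must be a DSCallback instance")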
@@ -26,7 +26,6 @@ from src.config import cifar_cfg, imagenet_cfg

 def create_dataset_cifar10(data_home, repeat_num=1, training=True):
     """Data operations."""
-    ds.config.set_seed(1)
     data_dir = os.path.join(data_home, "cifar-10-batches-bin")
     if not training:
         data_dir = os.path.join(data_home, "cifar-10-verify-bin")
@@ -28,7 +28,6 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True

 def vgg_create_dataset(data_home, image_size, batch_size, rank_id=0, rank_size=1, repeat_num=1, training=True):
     """Data operations."""
-    de.config.set_seed(1)
     data_dir = os.path.join(data_home, "cifar-10-batches-bin")
     if not training:
         data_dir = os.path.join(data_home, "cifar-10-verify-bin")
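With the hard-coded config.set_seed(1) calls removed from these dataset helpers, reproducibility for the example scripts is expected to come from a single seed set at the training-script level through the unified API. A hedged sketch of the caller side; the src.dataset module path and the data directory are assumptions for illustration, not part of this diff:

    import mindspore as ms
    from src.dataset import create_dataset_cifar10  # assumed location of the helper shown above

    ms.set_seed(1)  # replaces the seed the helper previously set internally
    dataset = create_dataset_cifar10("/path/to/cifar-10", repeat_num=1, training=True)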