forked from mindspore-Ecosystem/mindspore
get default value if num_parallel_worker is None
This commit is contained in:
parent
fe82d82155
commit
365c901ee0
|
@ -22,7 +22,7 @@ import sys
|
|||
from mindspore import log as logger
|
||||
from . import datasets as de
|
||||
from ..transforms.vision.utils import Inter, Border
|
||||
|
||||
from ..core.configuration import config
|
||||
|
||||
def serialize(dataset, json_filepath=None):
|
||||
"""
|
||||
|
@ -164,6 +164,8 @@ def traverse(node):
|
|||
node_repr[k] = v.to_json()
|
||||
elif k in set(['schema', 'dataset_files', 'dataset_dir', 'schema_file_path']):
|
||||
expand_path(node_repr, k, v)
|
||||
elif k == "num_parallel_workers" and v is None:
|
||||
node_repr[k] = config.get_num_parallel_workers()
|
||||
else:
|
||||
node_repr[k] = v
|
||||
|
||||
|
|
|
@ -84,12 +84,11 @@ def test_pipeline():
|
|||
num_parallel_workers_original = ds.config.get_num_parallel_workers()
|
||||
|
||||
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
|
||||
ds.config.set_num_parallel_workers(2)
|
||||
data1 = data1.map(input_columns=["image"], operations=[c_vision.Decode(True)])
|
||||
ds.serialize(data1, "testpipeline.json")
|
||||
|
||||
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
|
||||
ds.config.set_num_parallel_workers(4)
|
||||
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_parallel_workers=num_parallel_workers_original,
|
||||
shuffle=False)
|
||||
data2 = data2.map(input_columns=["image"], operations=[c_vision.Decode(True)])
|
||||
ds.serialize(data2, "testpipeline2.json")
|
||||
|
||||
|
|
Loading…
Reference in New Issue