forked from mindspore-Ecosystem/mindspore
!2766 get default value if num_parallel_workers is None
Merge pull request !2766 from yanghaitao/yht_serialize_num_parallel_worker
This commit is contained in:
commit
9377e432d2
|
@ -22,7 +22,7 @@ import sys
|
||||||
from mindspore import log as logger
|
from mindspore import log as logger
|
||||||
from . import datasets as de
|
from . import datasets as de
|
||||||
from ..transforms.vision.utils import Inter, Border
|
from ..transforms.vision.utils import Inter, Border
|
||||||
|
from ..core.configuration import config
|
||||||
|
|
||||||
def serialize(dataset, json_filepath=None):
|
def serialize(dataset, json_filepath=None):
|
||||||
"""
|
"""
|
||||||
|
@ -164,6 +164,8 @@ def traverse(node):
|
||||||
node_repr[k] = v.to_json()
|
node_repr[k] = v.to_json()
|
||||||
elif k in set(['schema', 'dataset_files', 'dataset_dir', 'schema_file_path']):
|
elif k in set(['schema', 'dataset_files', 'dataset_dir', 'schema_file_path']):
|
||||||
expand_path(node_repr, k, v)
|
expand_path(node_repr, k, v)
|
||||||
|
elif k == "num_parallel_workers" and v is None:
|
||||||
|
node_repr[k] = config.get_num_parallel_workers()
|
||||||
else:
|
else:
|
||||||
node_repr[k] = v
|
node_repr[k] = v
|
||||||
|
|
||||||
|
|
|
@ -84,12 +84,11 @@ def test_pipeline():
|
||||||
num_parallel_workers_original = ds.config.get_num_parallel_workers()
|
num_parallel_workers_original = ds.config.get_num_parallel_workers()
|
||||||
|
|
||||||
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
|
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
|
||||||
ds.config.set_num_parallel_workers(2)
|
|
||||||
data1 = data1.map(input_columns=["image"], operations=[c_vision.Decode(True)])
|
data1 = data1.map(input_columns=["image"], operations=[c_vision.Decode(True)])
|
||||||
ds.serialize(data1, "testpipeline.json")
|
ds.serialize(data1, "testpipeline.json")
|
||||||
|
|
||||||
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
|
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_parallel_workers=num_parallel_workers_original,
|
||||||
ds.config.set_num_parallel_workers(4)
|
shuffle=False)
|
||||||
data2 = data2.map(input_columns=["image"], operations=[c_vision.Decode(True)])
|
data2 = data2.map(input_columns=["image"], operations=[c_vision.Decode(True)])
|
||||||
ds.serialize(data2, "testpipeline2.json")
|
ds.serialize(data2, "testpipeline2.json")
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue