!12977 PyNative GPU benchmark
From: @jojobugfree Reviewed-by: @kisnwang,@chujinjin Signed-off-by: @chujinjin
This commit is contained in:
commit
c99929a950
|
@ -94,11 +94,19 @@ class MyTimeMonitor(Callback):
|
|||
|
||||
def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="GPU", dtype="fp16",
|
||||
device_num=1):
|
||||
if args_opt.mode == "GRAPH":
|
||||
ds_num_parallel_worker = 4
|
||||
map_num_parallel_worker = 8
|
||||
batch_num_parallel_worker = None
|
||||
else:
|
||||
ds_num_parallel_worker = 2
|
||||
map_num_parallel_worker = 3
|
||||
batch_num_parallel_worker = 2
|
||||
ds.config.set_numa_enable(True)
|
||||
if device_num == 1:
|
||||
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=4, shuffle=True)
|
||||
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=ds_num_parallel_worker, shuffle=True)
|
||||
else:
|
||||
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=4, shuffle=True,
|
||||
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=ds_num_parallel_worker, shuffle=True,
|
||||
num_shards=device_num, shard_id=get_rank())
|
||||
image_size = 224
|
||||
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
|
||||
|
@ -127,9 +135,9 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="
|
|||
]
|
||||
if dtype == "fp32":
|
||||
trans.append(C.HWC2CHW())
|
||||
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
|
||||
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=map_num_parallel_worker)
|
||||
# apply batch operations
|
||||
data_set = data_set.batch(batch_size, drop_remainder=True)
|
||||
data_set = data_set.batch(batch_size, drop_remainder=True, num_parallel_workers=batch_num_parallel_worker)
|
||||
# apply dataset repeat operation
|
||||
if repeat_num > 1:
|
||||
data_set = data_set.repeat(repeat_num)
|
||||
|
@ -165,14 +173,16 @@ def train():
|
|||
# init context
|
||||
if args_opt.mode == "GRAPH":
|
||||
mode = context.GRAPH_MODE
|
||||
all_reduce_fusion_config = [85, 160]
|
||||
else:
|
||||
mode = context.PYNATIVE_MODE
|
||||
all_reduce_fusion_config = [30, 90, 160]
|
||||
context.set_context(mode=mode, device_target=dev, save_graphs=False)
|
||||
if args_opt.run_distribute:
|
||||
init()
|
||||
device_num = get_group_size()
|
||||
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True, all_reduce_fusion_config=[85, 160])
|
||||
gradients_mean=True, all_reduce_fusion_config=all_reduce_fusion_config)
|
||||
ckpt_save_dir = ckpt_save_dir + "ckpt_" + str(get_rank()) + "/"
|
||||
|
||||
# create dataset
|
||||
|
|
Loading…
Reference in New Issue