!12977 PyNative GPU benchmark

From: @jojobugfree
Reviewed-by: @kisnwang, @chujinjin
Signed-off-by: @chujinjin
Commit c99929a950, committed by mindspore-ci-bot on 2021-03-08 21:26:56 +08:00 via Gitee
1 changed file with 15 additions and 5 deletions


@@ -94,11 +94,19 @@ class MyTimeMonitor(Callback):
 def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="GPU", dtype="fp16",
                    device_num=1):
+    if args_opt.mode == "GRAPH":
+        ds_num_parallel_worker = 4
+        map_num_parallel_worker = 8
+        batch_num_parallel_worker = None
+    else:
+        ds_num_parallel_worker = 2
+        map_num_parallel_worker = 3
+        batch_num_parallel_worker = 2
     ds.config.set_numa_enable(True)
     if device_num == 1:
-        data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=4, shuffle=True)
+        data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=ds_num_parallel_worker, shuffle=True)
     else:
-        data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=4, shuffle=True,
+        data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=ds_num_parallel_worker, shuffle=True,
                                          num_shards=device_num, shard_id=get_rank())
     image_size = 224
     mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
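
The added branch gives PyNative mode a lighter data pipeline, presumably because eager execution in the Python frontend competes with the dataset worker processes for host CPU. A minimal standalone sketch of the same worker-count logic (make_pipeline, data_dir, and the Decode/Resize ops are illustrative, not part of this PR):

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as C

def make_pipeline(data_dir, mode="PYNATIVE", batch_size=32):
    # Worker counts mirror this PR: GRAPH mode affords more dataset
    # parallelism, PyNative leaves more host CPU to the frontend.
    if mode == "GRAPH":
        ds_workers, map_workers, batch_workers = 4, 8, None
    else:
        ds_workers, map_workers, batch_workers = 2, 3, 2
    data_set = ds.ImageFolderDataset(data_dir, num_parallel_workers=ds_workers, shuffle=True)
    data_set = data_set.map(operations=[C.Decode(), C.Resize(224)], input_columns="image",
                            num_parallel_workers=map_workers)
    # num_parallel_workers=None makes batch() use the global default
    return data_set.batch(batch_size, drop_remainder=True, num_parallel_workers=batch_workers)
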
@@ -127,9 +135,9 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="
     ]
     if dtype == "fp32":
         trans.append(C.HWC2CHW())
-    data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
+    data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=map_num_parallel_worker)
     # apply batch operations
-    data_set = data_set.batch(batch_size, drop_remainder=True)
+    data_set = data_set.batch(batch_size, drop_remainder=True, num_parallel_workers=batch_num_parallel_worker)
     # apply dataset repeat operation
     if repeat_num > 1:
         data_set = data_set.repeat(repeat_num)
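
batch() now takes an explicit worker count as well; when batch_num_parallel_worker is None (the GRAPH branch), the operation falls back to MindSpore's process-wide default. A short sketch of setting that default globally instead (illustrative, not part of this PR):

import mindspore.dataset as ds

# Fallback used by any dataset op whose num_parallel_workers is None.
ds.config.set_num_parallel_workers(8)
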
@@ -165,14 +173,16 @@ def train():
     # init context
     if args_opt.mode == "GRAPH":
         mode = context.GRAPH_MODE
+        all_reduce_fusion_config = [85, 160]
     else:
         mode = context.PYNATIVE_MODE
+        all_reduce_fusion_config = [30, 90, 160]
     context.set_context(mode=mode, device_target=dev, save_graphs=False)
     if args_opt.run_distribute:
         init()
         device_num = get_group_size()
         context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-                                          gradients_mean=True, all_reduce_fusion_config=[85, 160])
+                                          gradients_mean=True, all_reduce_fusion_config=all_reduce_fusion_config)
         ckpt_save_dir = ckpt_save_dir + "ckpt_" + str(get_rank()) + "/"
     # create dataset
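
The AllReduce fusion split points are no longer hard-coded for GRAPH mode: all_reduce_fusion_config lists the parameter indices at which gradient AllReduce operations are grouped during data-parallel training, and the PyNative path gets a finer three-group split. A hedged standalone sketch of the distributed setup (assumes a GPU/NCCL environment started under a launcher such as mpirun; mode_name stands in for args_opt.mode):

from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_group_size

mode_name = "PYNATIVE"  # illustrative stand-in for args_opt.mode
fusion_config = [85, 160] if mode_name == "GRAPH" else [30, 90, 160]
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
init()  # NCCL initialization; needs a distributed launcher
context.set_auto_parallel_context(device_num=get_group_size(),
                                  parallel_mode=ParallelMode.DATA_PARALLEL,
                                  gradients_mean=True,
                                  all_reduce_fusion_config=fusion_config)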