forked from mindspore-Ecosystem/mindspore
Update submodule mindspore/akg
Enable Graph Kernel for W&D Host-Device Mode and ResNext50 on GPU Enable akg-conv2d in yolov3_darknet on GPU
This commit is contained in:
parent
43174475e6
commit
448d1cdbb6
2
akg
2
akg
|
@ -1 +1 @@
|
|||
Subproject commit 83df651adb471ec229fc1732b894dd2142946233
|
||||
Subproject commit 32af460cac1bb7d76bc1fd41f5866107cfffe1b9
|
|
@ -100,12 +100,14 @@ class Conv2D(Expander):
|
|||
check_nd(stride, 4)
|
||||
n0, h0, w0, c0 = shape_0
|
||||
n1, h1, w1, c1 = shape_1
|
||||
if n0 < N0_CHANNEL_ALIGN:
|
||||
raise GKException("N({}) channel of first input should >= {}".format(n0, N0_CHANNEL_ALIGN))
|
||||
if n0 <= N0_CHANNEL_ALIGN:
|
||||
raise GKException("N({}) channel of first input should > {}".format(n0, N0_CHANNEL_ALIGN))
|
||||
if n1 < N1_CHANNEL_ALIGN:
|
||||
raise GKException("N({}) channel of second input should >= {}".format(n1, N1_CHANNEL_ALIGN))
|
||||
if c0 != c1 or c0 < C_CHANNEL_ALIGN:
|
||||
raise GKException("C channel of inputs({}, {}) should be same and >= {}".format(c0, c1, C_CHANNEL_ALIGN))
|
||||
if stride != [1, 1, 2, 2]:
|
||||
raise GKException("Stride H and W should be [2, 2] but got [{}, {}]".format(stride[2], stride[3]))
|
||||
# n0 pad
|
||||
n0 = ((n0 + N0_CHANNEL_ALIGN - 1) // N0_CHANNEL_ALIGN) * N0_CHANNEL_ALIGN
|
||||
# h0, w0 pad
|
||||
|
|
|
@ -121,7 +121,8 @@ def apply_eval(eval_param):
|
|||
|
||||
def set_graph_kernel_context(run_platform, net_name):
|
||||
if run_platform == "GPU" and net_name == "resnet101":
|
||||
context.set_context(enable_graph_kernel=True, graph_kernel_flags="--enable_parallel_fusion")
|
||||
context.set_context(enable_graph_kernel=True)
|
||||
context.set_context(graph_kernel_flags="--enable_parallel_fusion --enable_expand_ops=Conv2D")
|
||||
|
||||
if __name__ == '__main__':
|
||||
target = args_opt.device_target
|
||||
|
|
|
@ -123,6 +123,10 @@ def get_result(model, top1_correct, top5_correct, img_tot):
|
|||
config.logger.info('after results=%s', results)
|
||||
return results
|
||||
|
||||
def set_graph_kernel_context(device_target):
|
||||
if device_target == "GPU":
|
||||
context.set_context(enable_graph_kernel=True)
|
||||
|
||||
@moxing_wrapper()
|
||||
def test(cloud_args=None):
|
||||
"""test"""
|
||||
|
@ -131,6 +135,7 @@ def test(cloud_args=None):
|
|||
device_target=config.device_target, save_graphs=False)
|
||||
if os.getenv('DEVICE_ID', "not_set").isdigit():
|
||||
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
|
||||
set_graph_kernel_context(config.device_target)
|
||||
|
||||
# init distributed
|
||||
if config.run_distribute:
|
||||
|
|
|
@ -134,12 +134,17 @@ def set_parameters():
|
|||
config.logger = get_logger(config.outputs_dir, config.rank)
|
||||
return config
|
||||
|
||||
def set_graph_kernel_context(device_target):
|
||||
if device_target == "GPU":
|
||||
context.set_context(enable_graph_kernel=True)
|
||||
|
||||
@moxing_wrapper()
|
||||
def train():
|
||||
"""training process"""
|
||||
set_parameters()
|
||||
if os.getenv('DEVICE_ID', "not_set").isdigit():
|
||||
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
|
||||
set_graph_kernel_context(config.device_target)
|
||||
|
||||
# init distributed
|
||||
if config.run_distribute:
|
||||
|
|
|
@ -64,7 +64,8 @@ def set_graph_kernel_context():
|
|||
context.set_context(enable_graph_kernel=True)
|
||||
context.set_context(graph_kernel_flags="--enable_parallel_fusion "
|
||||
"--disable_expand_ops=BatchNorm,BatchNormGrad "
|
||||
"--disable_cluster_ops=ReduceMax,Reshape")
|
||||
"--disable_cluster_ops=ReduceMax,Reshape "
|
||||
"--enable_expand_ops=Conv2D")
|
||||
|
||||
def network_init(args):
|
||||
devid = int(os.getenv('DEVICE_ID', '0'))
|
||||
|
|
|
@ -145,6 +145,8 @@ def train_and_eval(config):
|
|||
if __name__ == "__main__":
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target=cfg.device_target)
|
||||
if cfg.device_target == "GPU":
|
||||
context.set_context(enable_graph_kernel=True)
|
||||
context.set_context(variable_memory_max_size="24GB")
|
||||
context.set_context(enable_sparse=True)
|
||||
init()
|
||||
|
|
Loading…
Reference in New Issue