diff --git a/mindspore/_extends/graph_kernel/expanders/conv2d.py b/mindspore/_extends/graph_kernel/expanders/conv2d.py index ac93090aaaa..24908d25d6c 100644 --- a/mindspore/_extends/graph_kernel/expanders/conv2d.py +++ b/mindspore/_extends/graph_kernel/expanders/conv2d.py @@ -21,7 +21,7 @@ from ._utils import Expander, ExpanderInfoValidator as VLD M_ALIGN = 16 N_ALIGN = 16 K_ALIGN = 8 -K_LIMIT = 4096 +K_LIMIT = 800 MNK_LIMIT = 3 * (10 ** 10) N0_CHANNEL_ALIGN = 16 N1_CHANNEL_ALIGN = 16 diff --git a/model_zoo/official/cv/alexnet/train.py b/model_zoo/official/cv/alexnet/train.py index 17312ee382a..663817a4e75 100644 --- a/model_zoo/official/cv/alexnet/train.py +++ b/model_zoo/official/cv/alexnet/train.py @@ -58,6 +58,7 @@ def train_alexnet(): context.set_context(save_graphs=False) if device_target == "GPU": context.set_context(enable_graph_kernel=True) + context.set_context(graph_kernel_flags="--enable_cluster_ops=MatMul") device_num = get_device_num() if config.dataset_name == "cifar10": @@ -124,7 +125,8 @@ def train_alexnet(): model = Model(network, loss_fn=loss, optimizer=opt, metrics=metrics, amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=loss_scale_manager) elif device_target == "GPU": - model = Model(network, loss_fn=loss, optimizer=opt, metrics=metrics, loss_scale_manager=loss_scale_manager) + model = Model(network, loss_fn=loss, optimizer=opt, metrics=metrics, amp_level="O2", + loss_scale_manager=loss_scale_manager) else: raise ValueError("Unsupported platform.")