!17300 Enable Graph Kernel for Lenet and LenetQuant on GPU

From: @zengzitao
Reviewed-by: @ckey_dou,@gaoxiong1
Signed-off-by: @ckey_dou
This commit is contained in:
mindspore-ci-bot 2021-05-31 11:35:04 +08:00 committed by Gitee
commit 53e8cacede
2 changed files with 7 additions and 0 deletions

View File

@ -33,9 +33,11 @@ from mindspore.common import set_seed
set_seed(1)
def modelarts_pre_process():
pass
@moxing_wrapper(pre_process=modelarts_pre_process)
def train_lenet():
@ -53,6 +55,8 @@ def train_lenet():
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory=config.ckpt_path, config=config_ck)
if config.device_target != "Ascend":
if config.device_target == "GPU":
context.set_context(enable_graph_kernel=True)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
else:
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}, amp_level="O2")
@ -60,5 +64,6 @@ def train_lenet():
print("============== Starting Training ==============")
model.train(config.epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()])
if __name__ == "__main__":
train_lenet()

View File

@ -51,6 +51,8 @@ if __name__ == "__main__":
ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, 1)
step_size = ds_train.get_dataset_size()
if args.device_target == "GPU":
context.set_context(enable_graph_kernel=True)
# define fusion network
network = LeNet5Fusion(cfg.num_classes)