forked from mindspore-Ecosystem/mindspore
!17300 Enable Graph Kernel for Lenet and LenetQuant on GPU
From: @zengzitao Reviewed-by: @ckey_dou,@gaoxiong1 Signed-off-by: @ckey_dou
This commit is contained in:
commit
53e8cacede
|
@ -33,9 +33,11 @@ from mindspore.common import set_seed
|
|||
|
||||
set_seed(1)
|
||||
|
||||
|
||||
def modelarts_pre_process():
|
||||
pass
|
||||
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def train_lenet():
|
||||
|
||||
|
@ -53,6 +55,8 @@ def train_lenet():
|
|||
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory=config.ckpt_path, config=config_ck)
|
||||
|
||||
if config.device_target != "Ascend":
|
||||
if config.device_target == "GPU":
|
||||
context.set_context(enable_graph_kernel=True)
|
||||
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
||||
else:
|
||||
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}, amp_level="O2")
|
||||
|
@ -60,5 +64,6 @@ def train_lenet():
|
|||
print("============== Starting Training ==============")
|
||||
model.train(config.epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_lenet()
|
||||
|
|
|
@ -51,6 +51,8 @@ if __name__ == "__main__":
|
|||
ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, 1)
|
||||
step_size = ds_train.get_dataset_size()
|
||||
|
||||
if args.device_target == "GPU":
|
||||
context.set_context(enable_graph_kernel=True)
|
||||
# define fusion network
|
||||
network = LeNet5Fusion(cfg.num_classes)
|
||||
|
||||
|
|
Loading…
Reference in New Issue