From 4c2affdaad5887072d08aede19d2fcda5c07b627 Mon Sep 17 00:00:00 2001 From: ZPaC Date: Thu, 19 May 2022 22:41:18 +0800 Subject: [PATCH] Add epoch number for lenet train --- tests/st/frontend_compile_cache/run_lenet_ps.py | 2 +- tests/st/ps/full_ps/test_full_ps_lenet.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/st/frontend_compile_cache/run_lenet_ps.py b/tests/st/frontend_compile_cache/run_lenet_ps.py index 07a735b59c9..cb364c05774 100644 --- a/tests/st/frontend_compile_cache/run_lenet_ps.py +++ b/tests/st/frontend_compile_cache/run_lenet_ps.py @@ -127,7 +127,7 @@ if __name__ == "__main__": model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) ds_train = create_dataset(os.path.join(DATASET_PATH, "train"), 32, 1) - model.train(1, ds_train, callbacks=[LossMonitor()], dataset_sink_mode=False) + model.train(3, ds_train, callbacks=[LossMonitor()], dataset_sink_mode=False) ds_eval = create_dataset(os.path.join(DATASET_PATH, "test"), 32, 1) acc = model.eval(ds_eval, dataset_sink_mode=False) diff --git a/tests/st/ps/full_ps/test_full_ps_lenet.py b/tests/st/ps/full_ps/test_full_ps_lenet.py index 0374401798b..6c58d4d2c9c 100644 --- a/tests/st/ps/full_ps/test_full_ps_lenet.py +++ b/tests/st/ps/full_ps/test_full_ps_lenet.py @@ -129,7 +129,7 @@ if __name__ == "__main__": model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) ds_train = create_dataset(os.path.join(dataset_path, "train"), 32, 1) - model.train(1, ds_train, callbacks=[LossMonitor()], dataset_sink_mode=False) + model.train(3, ds_train, callbacks=[LossMonitor()], dataset_sink_mode=False) ds_eval = create_dataset(os.path.join(dataset_path, "test"), 32, 1) acc = model.eval(ds_eval, dataset_sink_mode=False)