forked from mindspore-Ecosystem/mindspore
!34683 Add epoch number for lenet train
Merge pull request !34683 from ZPaC/fix-full-ps-lenet-acc
This commit is contained in:
commit
57216a7b80
|
@ -127,7 +127,7 @@ if __name__ == "__main__":
|
|||
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
||||
|
||||
ds_train = create_dataset(os.path.join(DATASET_PATH, "train"), 32, 1)
|
||||
model.train(1, ds_train, callbacks=[LossMonitor()], dataset_sink_mode=False)
|
||||
model.train(3, ds_train, callbacks=[LossMonitor()], dataset_sink_mode=False)
|
||||
|
||||
ds_eval = create_dataset(os.path.join(DATASET_PATH, "test"), 32, 1)
|
||||
acc = model.eval(ds_eval, dataset_sink_mode=False)
|
||||
|
|
|
@ -129,7 +129,7 @@ if __name__ == "__main__":
|
|||
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
||||
|
||||
ds_train = create_dataset(os.path.join(dataset_path, "train"), 32, 1)
|
||||
model.train(1, ds_train, callbacks=[LossMonitor()], dataset_sink_mode=False)
|
||||
model.train(3, ds_train, callbacks=[LossMonitor()], dataset_sink_mode=False)
|
||||
|
||||
ds_eval = create_dataset(os.path.join(dataset_path, "test"), 32, 1)
|
||||
acc = model.eval(ds_eval, dataset_sink_mode=False)
|
||||
|
|
Loading…
Reference in New Issue