forked from mindspore-Ecosystem/mindspore
Add epoch number for lenet train
This commit is contained in:
parent
b5249cdedd
commit
4c2affdaad
|
@ -127,7 +127,7 @@ if __name__ == "__main__":
|
|||
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
||||
|
||||
ds_train = create_dataset(os.path.join(DATASET_PATH, "train"), 32, 1)
|
||||
model.train(1, ds_train, callbacks=[LossMonitor()], dataset_sink_mode=False)
|
||||
model.train(3, ds_train, callbacks=[LossMonitor()], dataset_sink_mode=False)
|
||||
|
||||
ds_eval = create_dataset(os.path.join(DATASET_PATH, "test"), 32, 1)
|
||||
acc = model.eval(ds_eval, dataset_sink_mode=False)
|
||||
|
|
|
@ -129,7 +129,7 @@ if __name__ == "__main__":
|
|||
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
||||
|
||||
ds_train = create_dataset(os.path.join(dataset_path, "train"), 32, 1)
|
||||
model.train(1, ds_train, callbacks=[LossMonitor()], dataset_sink_mode=False)
|
||||
model.train(3, ds_train, callbacks=[LossMonitor()], dataset_sink_mode=False)
|
||||
|
||||
ds_eval = create_dataset(os.path.join(dataset_path, "test"), 32, 1)
|
||||
acc = model.eval(ds_eval, dataset_sink_mode=False)
|
||||
|
|
Loading…
Reference in New Issue