!812 [CR] add lenet train and eval st case

Merge pull request !812 from jinyaohui/train_eval
This commit is contained in:
mindspore-ci-bot 2020-05-07 18:01:21 +08:00 committed by Gitee
commit 8a484dbd0b
1 changed files with 72 additions and 7 deletions

View File

@ -13,18 +13,26 @@
# limitations under the License.
# ============================================================================
import os
import pytest
import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.nn.optim import Momentum
import mindspore.context as context
from mindspore.ops import operations as P
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn import Dense
from mindspore.common.initializer import initializer
import mindspore.nn as nn
from mindspore.nn import Dense, TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Momentum
from mindspore.nn.metrics import Accuracy
from mindspore.train import Model
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.model_zoo.lenet import LeNet5
from mindspore.train.callback import LossMonitor
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.transforms.vision import Inter
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
@ -64,7 +72,7 @@ class LeNet(nn.Cell):
def multisteplr(total_steps, gap, base_lr=0.9, gamma=0.1, dtype=mstype.float32):
lr = []
for step in range(total_steps):
lr_ = base_lr * gamma ** (step//gap)
lr_ = base_lr * gamma ** (step // gap)
lr.append(lr_)
return Tensor(np.array(lr), dtype)
@ -90,3 +98,60 @@ def test_train_lenet():
loss = train_network(data, label)
losses.append(loss)
print(losses)
def create_dataset(data_path, batch_size=32, repeat_size=1,
num_parallel_workers=1):
"""
create dataset for train or test
"""
# define dataset
mnist_ds = ds.MnistDataset(data_path)
resize_height, resize_width = 32, 32
rescale = 1.0 / 255.0
shift = 0.0
rescale_nml = 1 / 0.3081
shift_nml = -1 * 0.1307 / 0.3081
# define map operations
resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Bilinear mode
rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
rescale_op = CV.Rescale(rescale, shift)
hwc2chw_op = CV.HWC2CHW()
type_cast_op = C.TypeCast(mstype.int32)
# apply map operations on images
mnist_ds = mnist_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)
# apply DatasetOps
buffer_size = 10000
mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
mnist_ds = mnist_ds.repeat(repeat_size)
return mnist_ds
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_train_and_eval_lenet():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", enable_mem_reuse=False)
network = LeNet5(10)
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
print("============== Starting Training ==============")
ds_train = create_dataset(os.path.join('/home/workspace/mindspore_dataset/mnist', "train"), 32, 1)
model.train(1, ds_train, callbacks=[LossMonitor()], dataset_sink_mode=True)
print("============== Starting Testing ==============")
ds_eval = create_dataset(os.path.join('/home/workspace/mindspore_dataset/mnist', "test"), 32, 1)
acc = model.eval(ds_eval, dataset_sink_mode=True)
print("============== Accuracy:{} ==============".format(acc))