!812 [CR] add lenet train and eval st case

Merge pull request !812 from jinyaohui/train_eval
This commit is contained in:
mindspore-ci-bot 2020-05-07 18:01:21 +08:00 committed by Gitee
commit 8a484dbd0b
1 changed files with 72 additions and 7 deletions

View File

@ -13,18 +13,26 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
import os
import pytest import pytest
import numpy as np import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor from mindspore import Tensor
from mindspore.nn.optim import Momentum import mindspore.context as context
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.nn import TrainOneStepCell, WithLossCell import mindspore.nn as nn
from mindspore.nn import Dense from mindspore.nn import Dense, TrainOneStepCell, WithLossCell
from mindspore.common.initializer import initializer from mindspore.nn.optim import Momentum
from mindspore.nn.metrics import Accuracy
from mindspore.train import Model
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.model_zoo.lenet import LeNet5
from mindspore.train.callback import LossMonitor
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.transforms.vision import Inter
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
@ -64,7 +72,7 @@ class LeNet(nn.Cell):
def multisteplr(total_steps, gap, base_lr=0.9, gamma=0.1, dtype=mstype.float32): def multisteplr(total_steps, gap, base_lr=0.9, gamma=0.1, dtype=mstype.float32):
lr = [] lr = []
for step in range(total_steps): for step in range(total_steps):
lr_ = base_lr * gamma ** (step//gap) lr_ = base_lr * gamma ** (step // gap)
lr.append(lr_) lr.append(lr_)
return Tensor(np.array(lr), dtype) return Tensor(np.array(lr), dtype)
@ -90,3 +98,60 @@ def test_train_lenet():
loss = train_network(data, label) loss = train_network(data, label)
losses.append(loss) losses.append(loss)
print(losses) print(losses)
def create_dataset(data_path, batch_size=32, repeat_size=1,
num_parallel_workers=1):
"""
create dataset for train or test
"""
# define dataset
mnist_ds = ds.MnistDataset(data_path)
resize_height, resize_width = 32, 32
rescale = 1.0 / 255.0
shift = 0.0
rescale_nml = 1 / 0.3081
shift_nml = -1 * 0.1307 / 0.3081
# define map operations
resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Bilinear mode
rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
rescale_op = CV.Rescale(rescale, shift)
hwc2chw_op = CV.HWC2CHW()
type_cast_op = C.TypeCast(mstype.int32)
# apply map operations on images
mnist_ds = mnist_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)
# apply DatasetOps
buffer_size = 10000
mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
mnist_ds = mnist_ds.repeat(repeat_size)
return mnist_ds
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_train_and_eval_lenet():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", enable_mem_reuse=False)
network = LeNet5(10)
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
print("============== Starting Training ==============")
ds_train = create_dataset(os.path.join('/home/workspace/mindspore_dataset/mnist', "train"), 32, 1)
model.train(1, ds_train, callbacks=[LossMonitor()], dataset_sink_mode=True)
print("============== Starting Testing ==============")
ds_eval = create_dataset(os.path.join('/home/workspace/mindspore_dataset/mnist', "test"), 32, 1)
acc = model.eval(ds_eval, dataset_sink_mode=True)
print("============== Accuracy:{} ==============".format(acc))