# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Single-card system test: train ResNet-50 on CIFAR-10 and check accuracy.

Recovered from a whitespace-collapsed patch:
    From d5adfa52100405774c840f2ec9266a83a02ecc27
    From: chujinjin
    Date: Wed, 29 Apr 2020 10:48:15 +0800
    Subject: [PATCH] add accuracy for resnet cifar
    File: tests/st/tbe_networks/test_resnet_cifar_1p.py
"""

import os
import random
import shutil
import time

import numpy as np
import pytest

import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.transforms.vision.c_transforms as vision
import mindspore.nn as nn
import mindspore.ops.functional as F
from mindspore import Tensor
from mindspore import context
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from resnet import resnet50

# Fix every seed the test touches so runs are as repeatable as possible.
random.seed(1)
np.random.seed(1)
ds.config.set_seed(1)

data_home = "/home/workspace/mindspore_dataset"


def create_dataset(repeat_num=1, training=True, batch_size=32):
    """Build the CIFAR-10 input pipeline.

    Args:
        repeat_num: Number of times the batched dataset is repeated.
        training: When True, read the training directory and add random
            crop / horizontal flip augmentation; otherwise read the
            verification directory without augmentation.
        batch_size: Number of samples per batch (incomplete batches are
            dropped via ``drop_remainder=True``).

    Returns:
        The mapped, shuffled, batched and repeated dataset object.
    """
    data_dir = data_home + "/cifar-10-batches-bin"
    if not training:
        data_dir = data_home + "/cifar-10-verify-bin"
    data_set = ds.Cifar10Dataset(data_dir)

    resize_height = 224
    resize_width = 224
    rescale = 1.0 / 255.0
    shift = 0.0

    # define map operations
    random_crop_op = vision.RandomCrop(
        (32, 32), (4, 4, 4, 4))  # padding_mode default CONSTANT
    random_horizontal_op = vision.RandomHorizontalFlip()
    # interpolation default BILINEAR
    resize_op = vision.Resize((resize_height, resize_width))
    rescale_op = vision.Rescale(rescale, shift)
    # NOTE(review): these statistics are listed in the reverse order of the
    # commonly quoted RGB CIFAR-10 values; kept as-is, confirm channel order.
    normalize_op = vision.Normalize(
        (0.4465, 0.4822, 0.4914), (0.2010, 0.1994, 0.2023))
    changeswap_op = vision.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # Augmentation is applied only when training; the remaining ops always run.
    c_trans = []
    if training:
        c_trans = [random_crop_op, random_horizontal_op]
    c_trans += [resize_op, rescale_op, normalize_op,
                changeswap_op]

    # apply map operations on images
    data_set = data_set.map(input_columns="label", operations=type_cast_op)
    data_set = data_set.map(input_columns="image", operations=c_trans)

    # apply shuffle operations
    data_set = data_set.shuffle(buffer_size=1000)

    # apply batch operations
    data_set = data_set.batch(batch_size=batch_size, drop_remainder=True)

    # apply repeat operations
    data_set = data_set.repeat(repeat_num)

    return data_set


class CrossEntropyLoss(nn.Cell):
    """Softmax cross-entropy loss over integer labels, averaged with ReduceMean."""

    def __init__(self):
        super(CrossEntropyLoss, self).__init__()
        self.cross_entropy = P.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean()
        self.one_hot = P.OneHot()
        self.one = Tensor(1.0, mstype.float32)
        self.zero = Tensor(0.0, mstype.float32)

    def construct(self, logits, label):
        """Return the mean cross-entropy of ``logits`` against ``label``.

        ``label`` is one-hot encoded to the width of ``logits`` (its second
        dimension) before the loss is computed.
        """
        label = self.one_hot(label, F.shape(logits)[1], self.one, self.zero)
        # The op returns a tuple; element 0 is used as the loss here.
        loss = self.cross_entropy(logits, label)[0]
        loss = self.mean(loss, (-1,))
        return loss


class LossGet(Callback):
    """Callback that prints the step loss and remembers the last printed value.

    Args:
        per_print_times: Print (and record) the loss every this many steps;
            0 disables printing and recording.

    Raises:
        ValueError: If ``per_print_times`` is not a non-negative int.
    """

    def __init__(self, per_print_times=1):
        super(LossGet, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("print_step must be int and >= 0.")
        self._per_print_times = per_print_times
        self._loss = 0.0

    def step_end(self, run_context):
        """Check the step loss for NaN/Inf and print it on the configured cadence.

        Raises:
            ValueError: If the loss is NaN or infinite.
        """
        cb_params = run_context.original_args()
        loss = cb_params.net_outputs

        if isinstance(loss, (tuple, list)):
            if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
                loss = loss[0]

        if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
            loss = np.mean(loss.asnumpy())

        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1

        # np.mean over a float32 array yields np.float32, which is not a
        # subclass of Python float; accept NumPy floating scalars as well so
        # the NaN/Inf guard actually triggers.
        if isinstance(loss, (float, np.floating)) and (np.isnan(loss) or np.isinf(loss)):
            raise ValueError("epoch: {} step: {}. Invalid loss, terminating training."
                             .format(cb_params.cur_epoch_num, cur_step_in_epoch))
        if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
            self._loss = loss
            print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss))

    def get_loss(self):
        """Return the most recently recorded loss (0.0 before any step)."""
        return self._loss


def train_process(device_id, epoch_size, num_classes, device_num, batch_size):
    """Train ResNet-50 on CIFAR-10 and write a checkpoint.

    Creates a working directory named after ``device_id``, changes into it
    (the caller is responsible for changing back), trains for ``epoch_size``
    epochs and saves a checkpoint into that directory.

    Args:
        device_id: Ascend device index; also the working directory name.
        epoch_size: Number of epochs (also used as the dataset repeat count).
        num_classes: Number of output classes of the network.
        device_num: Unused here; kept for interface compatibility.
        batch_size: Training batch size.
    """
    work_dir = str(device_id)
    # exist_ok mirrors the old shell `mkdir`, which did not abort the test
    # when the directory was already present.
    os.makedirs(work_dir, exist_ok=True)
    os.chdir(work_dir)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    context.set_context(enable_task_sink=True, device_id=device_id)
    context.set_context(enable_loop_sink=True)
    context.set_context(enable_mem_reuse=True)

    net = resnet50(batch_size, num_classes)
    loss = CrossEntropyLoss()
    opt = Momentum(filter(lambda x: x.requires_grad,
                          net.get_parameters()), 0.01, 0.9)

    model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})

    dataset = create_dataset(epoch_size, training=True, batch_size=batch_size)
    batch_num = dataset.get_dataset_size()
    # Save once every `batch_num` steps and keep only the newest checkpoint.
    config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10_device_id_" + str(device_id), directory="./",
                                 config=config_ck)
    loss_cb = LossGet()
    model.train(epoch_size, dataset, callbacks=[ckpoint_cb, loss_cb])


# NOTE: this name shadows the builtin `eval`; it is kept unchanged because it
# is the function's public name.
def eval(batch_size, num_classes):
    """Evaluate the checkpoint written by ``train_process`` for device 0.

    Must be called with the current directory set to the training working
    directory, because the checkpoint path is relative.

    Args:
        batch_size: Batch size used to build the network and the dataset.
        num_classes: Number of output classes of the network.

    Returns:
        The metrics result of ``Model.eval`` (the test reads its ``'acc'`` key).
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    context.set_context(enable_task_sink=True, device_id=0)
    context.set_context(enable_loop_sink=True)
    context.set_context(enable_mem_reuse=True)

    net = resnet50(batch_size, num_classes)
    loss = CrossEntropyLoss()
    opt = Momentum(filter(lambda x: x.requires_grad,
                          net.get_parameters()), 0.01, 0.9)

    model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
    # The prefix matches train_process for device 0. The "-1_1562" suffix
    # presumably encodes epoch 1 / step 1562 for batch size 32 - verify if the
    # batch size or dataset changes.
    checkpoint_path = "./train_resnet_cifar10_device_id_0-1_1562.ckpt"
    param_dict = load_checkpoint(checkpoint_path)
    load_param_into_net(net, param_dict)
    net.set_train(False)
    # Use the same batch size the network was built with (previously the
    # dataset silently fell back to the default of 32).
    eval_dataset = create_dataset(1, training=False, batch_size=batch_size)
    res = model.eval(eval_dataset)
    print("result: ", res)
    return res


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_resnet_cifar_1p():
    """Train one epoch on one card and require accuracy above 0.35."""
    device_num = 1
    epoch_size = 1
    num_classes = 10
    batch_size = 32
    device_id = 0

    # train_process changes into ./<device_id>; remember where we started so
    # the directory can be restored and cleaned up even if a step fails.
    start_dir = os.getcwd()
    try:
        train_process(device_id, epoch_size, num_classes, device_num, batch_size)
        # Presumably gives checkpoint writing time to finish - TODO confirm.
        time.sleep(3)
        acc = eval(batch_size, num_classes)
    finally:
        os.chdir(start_dir)
        shutil.rmtree(os.path.join(start_dir, str(device_id)), ignore_errors=True)
    print("End training...")
    assert acc['acc'] > 0.35