diff --git a/tests/st/quantization/mobilenetv2_quant/dataset.py b/tests/st/quantization/mobilenetv2_quant/dataset.py
new file mode 100644
index 00000000000..6dd158f98ca
--- /dev/null
+++ b/tests/st/quantization/mobilenetv2_quant/dataset.py
@@ -0,0 +1,67 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""create train dataset."""
+
+from functools import partial
+import mindspore.dataset as ds
+import mindspore.common.dtype as mstype
+import mindspore.dataset.vision.c_transforms as C
+import mindspore.dataset.transforms.c_transforms as C2
+
+
+def create_dataset(dataset_path, config, repeat_num=1, batch_size=32):
+    """
+    Create a train dataset.
+
+    Args:
+        dataset_path (string): the path of the dataset.
+        config (EasyDict): the basic config for training.
+        repeat_num (int): the repeat count of the dataset. Default: 1.
+        batch_size (int): the batch size of the dataset. Default: 32.
+
+    Returns:
+        Dataset, the processed dataset.
+    """
+
+    load_func = partial(ds.Cifar10Dataset, dataset_path)
+    cifar_ds = load_func(num_parallel_workers=8, shuffle=False)
+
+    resize_height = config.image_height
+    resize_width = config.image_width
+    rescale = 1.0 / 255.0
+    shift = 0.0
+
+    # define map operations
+    # default interpolation is BILINEAR
+    resize_op = C.Resize((resize_height, resize_width))
+    rescale_op = C.Rescale(rescale, shift)
+    normalize_op = C.Normalize(
+        (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
+    changeswap_op = C.HWC2CHW()
+    type_cast_op = C2.TypeCast(mstype.int32)
+
+    c_trans = [resize_op, rescale_op, normalize_op, changeswap_op]
+
+    # apply map operations on images
+    cifar_ds = cifar_ds.map(input_columns="label", operations=type_cast_op)
+    cifar_ds = cifar_ds.map(input_columns="image", operations=c_trans)
+
+    # apply batch operations
+    cifar_ds = cifar_ds.batch(batch_size, drop_remainder=True)
+
+    # apply dataset repeat operation
+    cifar_ds = cifar_ds.repeat(repeat_num)
+
+    return cifar_ds
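A minimal way to exercise this helper (a sketch; the CIFAR-10 path is a placeholder, and the config object only needs the two fields the function actually reads):

```python
from easydict import EasyDict as ed
from dataset import create_dataset

cfg = ed({"image_height": 224, "image_width": 224})
data = create_dataset("/path/to/cifar-10-batches-bin", cfg, batch_size=32)

print(data.get_dataset_size())          # batches per epoch (drop_remainder=True)
for item in data.create_dict_iterator():
    print(item["image"].shape)          # (32, 3, 224, 224) after HWC2CHW + batch
    break
```

Note that `shuffle=False` keeps the input order reproducible, which the fixed loss thresholds in the tests below rely on.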
diff --git a/tests/st/quantization/mobilenetv2_quant/lr_generator.py b/tests/st/quantization/mobilenetv2_quant/lr_generator.py
new file mode 100644
index 00000000000..bc6ff8106e2
--- /dev/null
+++ b/tests/st/quantization/mobilenetv2_quant/lr_generator.py
@@ -0,0 +1,56 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""learning rate generator"""
+
+import math
+import numpy as np
+
+
+def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
+    """
+    Generate a learning rate array.
+
+    Args:
+        global_step (int): step to resume from; the returned array is sliced from this index.
+        lr_init (float): initial learning rate.
+        lr_end (float): final learning rate.
+        lr_max (float): maximum learning rate.
+        warmup_epochs (int): number of warmup epochs.
+        total_epochs (int): total number of training epochs.
+        steps_per_epoch (int): steps per epoch.
+
+    Returns:
+        np.array, learning rate array.
+    """
+    lr_each_step = []
+    total_steps = steps_per_epoch * total_epochs
+    warmup_steps = steps_per_epoch * warmup_epochs
+    for i in range(total_steps):
+        if i < warmup_steps:
+            lr = lr_init + (lr_max - lr_init) * i / warmup_steps
+        else:
+            lr = lr_end + \
+                (lr_max - lr_end) * \
+                (1. + math.cos(math.pi * (i - warmup_steps) /
+                               (total_steps - warmup_steps))) / 2.
+        if lr < 0.0:
+            lr = 0.0
+        lr_each_step.append(lr)
+
+    current_step = global_step
+    lr_each_step = np.array(lr_each_step).astype(np.float32)
+    learning_rate = lr_each_step[current_step:]
+
+    return learning_rate
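A small sketch of what the generator produces: linear warmup to lr_max, cosine decay toward lr_end, with global_step simply slicing into the precomputed schedule:

```python
import numpy as np
from lr_generator import get_lr

# 2 warmup epochs + 2 decay epochs, 5 steps each
lr = get_lr(global_step=0, lr_init=0.0, lr_end=0.0, lr_max=0.1,
            warmup_epochs=2, total_epochs=4, steps_per_epoch=5)
assert len(lr) == 20
assert np.isclose(lr[10], 0.1)   # cosine decay starts exactly at lr_max
# Resuming from step 3 is just a slice of the same schedule.
assert np.allclose(get_lr(3, 0.0, 0.0, 0.1, 2, 4, 5), lr[3:])
```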
diff --git a/tests/st/quantization/mobilenetv2_quant/mobilenetV2.py b/tests/st/quantization/mobilenetv2_quant/mobilenetV2.py
new file mode 100644
index 00000000000..969dd6cfb1d
--- /dev/null
+++ b/tests/st/quantization/mobilenetv2_quant/mobilenetV2.py
@@ -0,0 +1,263 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""MobileNetV2 Quant model definition"""
+
+import numpy as np
+
+import mindspore.nn as nn
+from mindspore.ops import operations as P
+from mindspore import Tensor
+
+__all__ = ['mobilenetV2']
+
+
+def _make_divisible(v, divisor, min_value=None):
+    if min_value is None:
+        min_value = divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that rounding down does not go down by more than 10%.
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
+
+
+class GlobalAvgPooling(nn.Cell):
+    """
+    Global avg pooling definition.
+
+    Returns:
+        Tensor, output tensor.
+
+    Examples:
+        >>> GlobalAvgPooling()
+    """
+
+    def __init__(self):
+        super(GlobalAvgPooling, self).__init__()
+        self.mean = P.ReduceMean(keep_dims=False)
+
+    def construct(self, x):
+        x = self.mean(x, (2, 3))
+        return x
+
+
+class ConvBNReLU(nn.Cell):
+    """
+    Convolution/Depthwise fused with BatchNorm and ReLU block definition.
+
+    Args:
+        in_planes (int): Input channel.
+        out_planes (int): Output channel.
+        kernel_size (int): Input kernel size.
+        stride (int): Stride size for the first convolutional layer. Default: 1.
+        groups (int): Channel group. 1 for plain convolution, the input channel count for depthwise. Default: 1.
+
+    Returns:
+        Tensor, output tensor.
+
+    Examples:
+        >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1)
+    """
+
+    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
+        super(ConvBNReLU, self).__init__()
+        padding = (kernel_size - 1) // 2
+        self.conv = nn.Conv2dBnAct(in_planes, out_planes, kernel_size,
+                                   stride=stride,
+                                   pad_mode='pad',
+                                   padding=padding,
+                                   group=groups,
+                                   has_bn=True,
+                                   activation='relu')
+
+    def construct(self, x):
+        x = self.conv(x)
+        return x
+
+
+class InvertedResidual(nn.Cell):
+    """
+    MobileNetV2 residual block definition.
+
+    Args:
+        inp (int): Input channel.
+        oup (int): Output channel.
+        stride (int): Stride size for the first convolutional layer. Default: 1.
+        expand_ratio (int): Expand ratio of the input channel.
+
+    Returns:
+        Tensor, output tensor.
+
+    Examples:
+        >>> InvertedResidual(3, 256, 1, 1)
+    """
+
+    def __init__(self, inp, oup, stride, expand_ratio):
+        super(InvertedResidual, self).__init__()
+        assert stride in [1, 2]
+
+        hidden_dim = int(round(inp * expand_ratio))
+        self.use_res_connect = stride == 1 and inp == oup
+
+        layers = []
+        if expand_ratio != 1:
+            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
+        layers.extend([
+            # dw
+            ConvBNReLU(hidden_dim, hidden_dim,
+                       stride=stride, groups=hidden_dim),
+            # pw-linear
+            nn.Conv2dBnAct(hidden_dim, oup, kernel_size=1, stride=1,
+                           pad_mode='pad', padding=0, group=1, has_bn=True)
+        ])
+        self.conv = nn.SequentialCell(layers)
+        self.add = P.TensorAdd()
+
+    def construct(self, x):
+        out = self.conv(x)
+        if self.use_res_connect:
+            out = self.add(out, x)
+        return out
+
+
+class mobilenetV2(nn.Cell):
+    """
+    MobileNetV2 fusion architecture.
+
+    Args:
+        num_classes (int): Number of classes. Default: 1000.
+        width_mult (float): Channel multiplier; results are rounded to a multiple of round_nearest. Default: 1.0.
+        has_dropout (bool): Whether dropout is used. Default: False.
+        inverted_residual_setting (list): Inverted residual settings. Default: None.
+        round_nearest (int): Round channel counts to a multiple of this value. Default: 8.
+
+    Returns:
+        Tensor, output tensor.
+
+    Examples:
+        >>> mobilenetV2(num_classes=1000)
+    """
+
+    def __init__(self, num_classes=1000, width_mult=1.,
+                 has_dropout=False, inverted_residual_setting=None, round_nearest=8):
+        super(mobilenetV2, self).__init__()
+        block = InvertedResidual
+        input_channel = 32
+        last_channel = 1280
+        # setting of inverted residual blocks
+        self.cfgs = inverted_residual_setting
+        if inverted_residual_setting is None:
+            self.cfgs = [
+                # t, c, n, s
+                [1, 16, 1, 1],
+                [6, 24, 2, 2],
+                [6, 32, 3, 2],
+                [6, 64, 4, 2],
+                [6, 96, 3, 1],
+                [6, 160, 3, 2],
+                [6, 320, 1, 1],
+            ]
+
+        # building first layer
+        input_channel = _make_divisible(
+            input_channel * width_mult, round_nearest)
+        self.out_channels = _make_divisible(
+            last_channel * max(1.0, width_mult), round_nearest)
+
+        features = [ConvBNReLU(3, input_channel, stride=2)]
+        # building inverted residual blocks
+        for t, c, n, s in self.cfgs:
+            output_channel = _make_divisible(c * width_mult, round_nearest)
+            for i in range(n):
+                stride = s if i == 0 else 1
+                features.append(
+                    block(input_channel, output_channel, stride, expand_ratio=t))
+                input_channel = output_channel
+        # building last several layers
+        features.append(ConvBNReLU(
+            input_channel, self.out_channels, kernel_size=1))
+        # wrap the feature layers in nn.SequentialCell
+        self.features = nn.SequentialCell(features)
+        # mobilenet head
+        head = ([GlobalAvgPooling(),
+                 nn.DenseBnAct(self.out_channels, num_classes,
+                               has_bias=True, has_bn=False)
+                 ] if not has_dropout else
+                [GlobalAvgPooling(),
+                 nn.Dropout(0.2),
+                 nn.DenseBnAct(self.out_channels, num_classes,
+                               has_bias=True, has_bn=False)
+                 ])
+        self.head = nn.SequentialCell(head)
+
+        # init weights
+        self.init_parameters_data()
+        self._initialize_weights()
+
+    def construct(self, x):
+        x = self.features(x)
+        x = self.head(x)
+        return x
+
+    def _initialize_weights(self):
+        """
+        Initialize weights.
+
+        Returns:
+            None.
+
+        Examples:
+            >>> _initialize_weights()
+        """
+        self.init_parameters_data()
+        for _, m in self.cells_and_names():
+            # reseed before every cell so the drawn weights are deterministic
+            np.random.seed(1)
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                w = Tensor(np.random.normal(0, np.sqrt(2. / n),
+                                            m.weight.data.shape).astype("float32"))
+                m.weight.set_data(w)
+                if m.bias is not None:
+                    m.bias.set_data(
+                        Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
+            elif isinstance(m, nn.Conv2dBnAct):
+                n = m.conv.kernel_size[0] * \
+                    m.conv.kernel_size[1] * m.conv.out_channels
+                w = Tensor(np.random.normal(0, np.sqrt(2. / n),
+                                            m.conv.weight.data.shape).astype("float32"))
+                m.conv.weight.set_data(w)
+                if m.conv.bias is not None:
+                    m.conv.bias.set_data(
+                        Tensor(np.zeros(m.conv.bias.data.shape, dtype="float32")))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.gamma.set_data(
+                    Tensor(np.ones(m.gamma.data.shape, dtype="float32")))
+                m.beta.set_data(
+                    Tensor(np.zeros(m.beta.data.shape, dtype="float32")))
+            elif isinstance(m, nn.Dense):
+                m.weight.set_data(Tensor(np.random.normal(
+                    0, 0.01, m.weight.data.shape).astype("float32")))
+                if m.bias is not None:
+                    m.bias.set_data(
+                        Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
+            elif isinstance(m, nn.DenseBnAct):
+                m.dense.weight.set_data(
+                    Tensor(np.random.normal(0, 0.01, m.dense.weight.data.shape).astype("float32")))
+                if m.dense.bias is not None:
+                    m.dense.bias.set_data(
+                        Tensor(np.zeros(m.dense.bias.data.shape, dtype="float32")))
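The channel rounding above is the subtlest arithmetic in this file; a quick check of its behavior (importing the private helper only for illustration, plus a forward-shape smoke test that assumes a supported backend):

```python
import numpy as np
from mindspore import Tensor
from mobilenetV2 import mobilenetV2, _make_divisible

# Round to the nearest multiple of divisor, but never lose more than 10%.
assert _make_divisible(32 * 0.75, 8) == 24   # already a clean multiple
assert _make_divisible(33, 8) == 32          # 32 >= 0.9 * 33, rounding down is fine
assert _make_divisible(10, 8) == 16          # 8 < 0.9 * 10, so bump up a divisor

net = mobilenetV2(num_classes=10)
out = net(Tensor(np.ones((1, 3, 224, 224), np.float32)))
print(out.shape)  # (1, 10)
```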
diff --git a/tests/st/quantization/mobilenetv2_quant/test_mobilenetv2_quant.py b/tests/st/quantization/mobilenetv2_quant/test_mobilenetv2_quant.py
new file mode 100644
index 00000000000..acb9531b9ce
--- /dev/null
+++ b/tests/st/quantization/mobilenetv2_quant/test_mobilenetv2_quant.py
@@ -0,0 +1,123 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Train Mobilenetv2_quant on Cifar10"""
+
+
+import pytest
+import numpy as np
+from easydict import EasyDict as ed
+
+from mindspore import context
+from mindspore import Tensor
+from mindspore import nn
+from mindspore.train.model import Model
+from mindspore.train.quant import quant
+from mindspore.common import set_seed
+
+from dataset import create_dataset
+from lr_generator import get_lr
+from utils import Monitor, CrossEntropyWithLabelSmooth
+from mobilenetV2 import mobilenetV2
+
+config_ascend_quant = ed({
+    "num_classes": 10,
+    "image_height": 224,
+    "image_width": 224,
+    "batch_size": 200,
+    "step_threshold": 10,
+    "data_load_mode": "mindata",
+    "epoch_size": 1,
+    "start_epoch": 200,
+    "warmup_epochs": 1,
+    "lr": 0.3,
+    "momentum": 0.9,
+    "weight_decay": 4e-5,
+    "label_smooth": 0.1,
+    "loss_scale": 1024,
+    "save_checkpoint": True,
+    "save_checkpoint_epochs": 1,
+    "keep_checkpoint_max": 300,
+    "save_checkpoint_path": "./checkpoint",
+    "quantization_aware": True,
+})
+
+dataset_path = "/dataset/workspace/mindspore_dataset/cifar-10-batches-bin/"
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_mobilenetv2_quant():
+    # the test_ prefix is required for pytest to collect this entry point
+    set_seed(1)
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    config = config_ascend_quant
+    print("training configure: {}".format(config))
+
+    epoch_size = config.epoch_size
+
+    # define network
+    network = mobilenetV2(num_classes=config.num_classes)
+    # define loss
+    if config.label_smooth > 0:
+        loss = CrossEntropyWithLabelSmooth(
+            smooth_factor=config.label_smooth, num_classes=config.num_classes)
+    else:
+        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
+    # define dataset
+    dataset = create_dataset(dataset_path=dataset_path,
+                             config=config,
+                             repeat_num=1,
+                             batch_size=config.batch_size)
+    step_size = dataset.get_dataset_size()
+
+    # convert fusion network to quantization aware network
+    network = quant.convert_quant_network(network,
+                                          bn_fold=True,
+                                          per_channel=[True, False],
+                                          symmetric=[True, False])
+
+    # get learning rate
+    lr = Tensor(get_lr(global_step=config.start_epoch * step_size,
+                       lr_init=0,
+                       lr_end=0,
+                       lr_max=config.lr,
+                       warmup_epochs=config.warmup_epochs,
+                       total_epochs=epoch_size + config.start_epoch,
+                       steps_per_epoch=step_size))
+
+    # define optimizer
+    opt = nn.Momentum(filter(lambda x: x.requires_grad, network.get_parameters()), lr, config.momentum,
+                      config.weight_decay)
+    # define model
+    model = Model(network, loss_fn=loss, optimizer=opt)
+
+    print("============== Starting Training ==============")
+    monitor = Monitor(lr_init=lr.asnumpy(),
+                      step_threshold=config.step_threshold)
+    callback = [monitor]
+    model.train(epoch_size, dataset, callbacks=callback,
+                dataset_sink_mode=False)
+    print("============== End Training ==============")
+
+    expect_avg_step_loss = 2.32
+    avg_step_loss = np.mean(np.array(monitor.losses))
+
+    print("average step loss:{}".format(avg_step_loss))
+    assert avg_step_loss < expect_avg_step_loss
+
+
+if __name__ == '__main__':
+    test_mobilenetv2_quant()
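One non-obvious detail above: `start_epoch=200` does not train 200 epochs; it fast-forwards the cosine schedule so the single test epoch runs at late-schedule (small) learning rates. A sketch of the arithmetic, with the step count assumed from CIFAR-10 (50000 images / batch size 200):

```python
from lr_generator import get_lr

step_size = 250          # assumed: 50000 CIFAR-10 images / batch_size 200
start_epoch, epoch_size = 200, 1
lr = get_lr(global_step=start_epoch * step_size,
            lr_init=0, lr_end=0, lr_max=0.3,
            warmup_epochs=1, total_epochs=epoch_size + start_epoch,
            steps_per_epoch=step_size)
# The schedule spans 201 epochs but is sliced at step 50000, leaving exactly
# one epoch of cosine-tail learning rates for the test run.
assert len(lr) == epoch_size * step_size
```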
"License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""MobileNetV2 utils""" + +import time +import numpy as np + +from mindspore.train.callback import Callback +from mindspore import Tensor +from mindspore import nn +from mindspore.nn.loss.loss import _Loss +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.common import dtype as mstype + + +class Monitor(Callback): + """ + Monitor loss and time. + + Args: + lr_init (numpy array): train lr + + Returns: + None + + Examples: + >>> Monitor(100,lr_init=Tensor([0.05]*100).asnumpy()) + """ + + def __init__(self, lr_init=None, step_threshold=10): + super(Monitor, self).__init__() + self.lr_init = lr_init + self.lr_init_len = len(lr_init) + self.step_threshold = step_threshold + + def epoch_begin(self, run_context): + self.losses = [] + self.epoch_time = time.time() + + def epoch_end(self, run_context): + cb_params = run_context.original_args() + + epoch_mseconds = (time.time() - self.epoch_time) * 1000 + per_step_mseconds = epoch_mseconds / cb_params.batch_num + print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:8.6f}".format(epoch_mseconds, + per_step_mseconds, + np.mean(self.losses))) + self.epoch_mseconds = epoch_mseconds + + def step_begin(self, run_context): + self.step_time = time.time() + + def step_end(self, run_context): + cb_params = run_context.original_args() + step_mseconds = (time.time() - self.step_time) * 1000 + step_loss = cb_params.net_outputs + + if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor): + step_loss = step_loss[0] + if isinstance(step_loss, Tensor): + step_loss = np.mean(step_loss.asnumpy()) + + self.losses.append(step_loss) + cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + + print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:8.6f}/{:5.3f}], time:[{:5.3f}], lr:[{:5.5f}]".format( + cb_params.cur_epoch_num, cb_params.epoch_num, cur_step_in_epoch + + 1, cb_params.batch_num, step_loss, + np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1])) + + if cb_params.cur_step_num == self.step_threshold: + run_context.request_stop() + + +class CrossEntropyWithLabelSmooth(_Loss): + """ + CrossEntropyWith LabelSmooth. + + Args: + smooth_factor (float): smooth factor, default=0. + num_classes (int): num classes + + Returns: + None. 
+
+
+class CrossEntropyWithLabelSmooth(_Loss):
+    """
+    Cross entropy with label smoothing.
+
+    Args:
+        smooth_factor (float): smooth factor. Default: 0.
+        num_classes (int): number of classes. Default: 1000.
+
+    Returns:
+        None.
+
+    Examples:
+        >>> CrossEntropyWithLabelSmooth(smooth_factor=0., num_classes=1000)
+    """
+
+    def __init__(self, smooth_factor=0., num_classes=1000):
+        super(CrossEntropyWithLabelSmooth, self).__init__()
+        self.onehot = P.OneHot()
+        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
+        self.off_value = Tensor(1.0 * smooth_factor /
+                                (num_classes - 1), mstype.float32)
+        self.ce = nn.SoftmaxCrossEntropyWithLogits()
+        self.mean = P.ReduceMean(False)
+        self.cast = P.Cast()
+
+    def construct(self, logit, label):
+        one_hot_label = self.onehot(self.cast(label, mstype.int32), F.shape(logit)[1],
+                                    self.on_value, self.off_value)
+        out_loss = self.ce(logit, one_hot_label)
+        out_loss = self.mean(out_loss, 0)
+        return out_loss
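The on/off values above implement standard label smoothing; a pure-numpy check of what the OneHot op produces for, say, label 3 with 10 classes:

```python
import numpy as np

smooth_factor, num_classes = 0.1, 10
on_value = 1.0 - smooth_factor                  # 0.9 for the true class
off_value = smooth_factor / (num_classes - 1)   # 0.1 / 9 for every other class
target = np.full(num_classes, off_value)
target[3] = on_value                            # label = 3
assert np.isclose(target.sum(), 1.0)            # still a valid distribution
```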
diff --git a/tests/st/quantization/resnet50_quant/dataset.py b/tests/st/quantization/resnet50_quant/dataset.py
new file mode 100755
index 00000000000..fd4df32d9f1
--- /dev/null
+++ b/tests/st/quantization/resnet50_quant/dataset.py
@@ -0,0 +1,68 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""create train dataset."""
+
+
+from functools import partial
+
+import mindspore.common.dtype as mstype
+import mindspore.dataset.engine as de
+import mindspore.dataset.transforms.c_transforms as C2
+import mindspore.dataset.vision.c_transforms as C
+
+
+def create_dataset(dataset_path, config, repeat_num=1, batch_size=32):
+    """
+    Create a train dataset.
+
+    Args:
+        dataset_path (string): the path of the dataset.
+        config (EasyDict): the basic config for training.
+        repeat_num (int): the repeat count of the dataset. Default: 1.
+        batch_size (int): the batch size of the dataset. Default: 32.
+
+    Returns:
+        Dataset, the processed dataset.
+    """
+
+    load_func = partial(de.Cifar10Dataset, dataset_path)
+    ds = load_func(num_parallel_workers=8, shuffle=False)
+
+    resize_height = config.image_height
+    resize_width = config.image_width
+
+    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
+    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
+
+    # define map operations
+    resize_op = C.Resize((resize_height, resize_width))
+    normalize_op = C.Normalize(mean=mean, std=std)
+    changeswap_op = C.HWC2CHW()
+    c_trans = [resize_op, normalize_op, changeswap_op]
+
+    type_cast_op = C2.TypeCast(mstype.int32)
+
+    ds = ds.map(operations=c_trans, input_columns="image",
+                num_parallel_workers=8)
+    ds = ds.map(operations=type_cast_op,
+                input_columns="label", num_parallel_workers=8)
+
+    # apply batch operations
+    ds = ds.batch(batch_size, drop_remainder=True)
+
+    # apply dataset repeat operation
+    ds = ds.repeat(repeat_num)
+
+    return ds
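Note the two dataset pipelines normalize differently: this one feeds raw 0-255 pixels to Normalize with 255-scaled ImageNet statistics, while the MobileNetV2 pipeline rescales to [0, 1] first and uses CIFAR-10 statistics. The two arrangements of the scaling are algebraically equivalent; only the statistics differ:

```python
import numpy as np

pixel = np.array([128.0, 64.0, 32.0], np.float32)
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

a = (pixel - mean * 255) / (std * 255)   # Normalize on raw pixels, as above
b = (pixel / 255 - mean) / std           # Rescale(1/255, 0) then Normalize
assert np.allclose(a, b)
```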
diff --git a/tests/st/quantization/resnet50_quant/lr_generator.py b/tests/st/quantization/resnet50_quant/lr_generator.py
new file mode 100755
index 00000000000..fe2a971ebfc
--- /dev/null
+++ b/tests/st/quantization/resnet50_quant/lr_generator.py
@@ -0,0 +1,93 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""learning rate generator"""
+
+import math
+import numpy as np
+
+
+def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode):
+    """
+    Generate a learning rate array.
+
+    Args:
+        lr_init (float): initial learning rate.
+        lr_end (float): final learning rate.
+        lr_max (float): maximum learning rate.
+        warmup_epochs (int): number of warmup epochs.
+        total_epochs (int): total number of training epochs.
+        steps_per_epoch (int): steps per epoch.
+        lr_decay_mode (string): learning rate decay mode, one of 'steps', 'poly', 'cosine' or default (linear).
+
+    Returns:
+        np.array, learning rate array.
+    """
+    lr_each_step = []
+    total_steps = steps_per_epoch * total_epochs
+    warmup_steps = steps_per_epoch * warmup_epochs
+    if lr_decay_mode == 'steps':
+        decay_epoch_index = [0.3 * total_steps,
+                             0.6 * total_steps, 0.8 * total_steps]
+        for i in range(total_steps):
+            if i < decay_epoch_index[0]:
+                lr = lr_max
+            elif i < decay_epoch_index[1]:
+                lr = lr_max * 0.1
+            elif i < decay_epoch_index[2]:
+                lr = lr_max * 0.01
+            else:
+                lr = lr_max * 0.001
+            lr_each_step.append(lr)
+    elif lr_decay_mode == 'poly':
+        if warmup_steps != 0:
+            inc_each_step = (float(lr_max) - float(lr_init)) / \
+                float(warmup_steps)
+        else:
+            inc_each_step = 0
+        for i in range(total_steps):
+            if i < warmup_steps:
+                lr = float(lr_init) + inc_each_step * float(i)
+            else:
+                base = (1.0 - (float(i) - float(warmup_steps)) /
+                        (float(total_steps) - float(warmup_steps)))
+                lr = float(lr_max) * base * base
+                if lr < 0.0:
+                    lr = 0.0
+            lr_each_step.append(lr)
+    elif lr_decay_mode == 'cosine':
+        decay_steps = total_steps - warmup_steps
+        for i in range(total_steps):
+            if i < warmup_steps:
+                lr_inc = (float(lr_max) - float(lr_init)) / \
+                    float(warmup_steps)
+                lr = float(lr_init) + lr_inc * (i + 1)
+            else:
+                linear_decay = (total_steps - i) / decay_steps
+                cosine_decay = 0.5 * \
+                    (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
+                decayed = linear_decay * cosine_decay + 0.00001
+                lr = lr_max * decayed
+            lr_each_step.append(lr)
+    else:
+        for i in range(total_steps):
+            if i < warmup_steps:
+                lr = lr_init + (lr_max - lr_init) * i / warmup_steps
+            else:
+                lr = lr_max - (lr_max - lr_end) * \
+                    (i - warmup_steps) / (total_steps - warmup_steps)
+            lr_each_step.append(lr)
+
+    learning_rate = np.array(lr_each_step).astype(np.float32)
+
+    return learning_rate
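A small check of the 'steps' branch (a sketch; the decay boundaries fall at 30%, 60% and 80% of total steps):

```python
import numpy as np
from lr_generator import get_lr

lr = get_lr(lr_init=0.01, lr_end=0.0, lr_max=0.1, warmup_epochs=0,
            total_epochs=10, steps_per_epoch=10, lr_decay_mode='steps')
assert len(lr) == 100
assert np.isclose(lr[29], 0.1)          # last step before the first boundary
assert np.isclose(lr[30], 0.1 * 0.1)    # 30% boundary
assert np.isclose(lr[60], 0.1 * 0.01)   # 60% boundary
assert np.isclose(lr[80], 0.1 * 0.001)  # 80% boundary
```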
diff --git a/tests/st/quantization/resnet50_quant/resnet_quant_manual.py b/tests/st/quantization/resnet50_quant/resnet_quant_manual.py
new file mode 100644
index 00000000000..0298971c03f
--- /dev/null
+++ b/tests/st/quantization/resnet50_quant/resnet_quant_manual.py
@@ -0,0 +1,354 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""ResNet."""
+
+import numpy as np
+
+import mindspore.nn as nn
+import mindspore.common.initializer as weight_init
+from mindspore import Tensor
+from mindspore.ops import operations as P
+from mindspore.nn import FakeQuantWithMinMax, Conv2dBnFoldQuant as Conv2dBatchNormQuant
+
+
+_ema_decay = 0.999
+_symmetric = True
+_fake = True
+_per_channel = True
+
+
+def _weight_variable(shape, factor=0.01):
+    init_value = np.random.randn(*shape).astype(np.float32) * factor
+    return Tensor(init_value)
+
+
+def _conv3x3(in_channel, out_channel, stride=1):
+    weight_shape = (out_channel, in_channel, 3, 3)
+    weight = _weight_variable(weight_shape)
+    return nn.Conv2d(in_channel, out_channel,
+                     kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight)
+
+
+def _conv1x1(in_channel, out_channel, stride=1):
+    weight_shape = (out_channel, in_channel, 1, 1)
+    weight = _weight_variable(weight_shape)
+    return nn.Conv2d(in_channel, out_channel,
+                     kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight)
+
+
+def _conv7x7(in_channel, out_channel, stride=1):
+    weight_shape = (out_channel, in_channel, 7, 7)
+    weight = _weight_variable(weight_shape)
+    return nn.Conv2d(in_channel, out_channel,
+                     kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight)
+
+
+def _bn(channel):
+    return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
+                          gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)
+
+
+def _bn_last(channel):
+    return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
+                          gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1)
+
+
+def _fc(in_channel, out_channel):
+    weight_shape = (out_channel, in_channel)
+    weight = _weight_variable(weight_shape)
+    return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0)
+
+
+class ConvBNReLU(nn.Cell):
+    """
+    Convolution/Depthwise fused with BatchNorm and ReLU block definition.
+
+    Args:
+        in_planes (int): Input channel.
+        out_planes (int): Output channel.
+        kernel_size (int): Input kernel size.
+        stride (int): Stride size for the first convolutional layer. Default: 1.
+        groups (int): Channel group. 1 for plain convolution, the input channel count for depthwise. Default: 1.
+
+    Returns:
+        Tensor, output tensor.
+
+    Examples:
+        >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1)
+    """
+
+    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
+        super(ConvBNReLU, self).__init__()
+        padding = (kernel_size - 1) // 2
+        conv = Conv2dBatchNormQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding,
+                                    group=groups, fake=_fake, per_channel=_per_channel, symmetric=_symmetric)
+        layers = [conv, nn.ActQuant(nn.ReLU())] if _fake else [conv, nn.ReLU()]
+        self.features = nn.SequentialCell(layers)
+
+    def construct(self, x):
+        output = self.features(x)
+        return output
+
+
+class ResidualBlock(nn.Cell):
+    """
+    ResNet V1 residual block definition.
+
+    Args:
+        in_channel (int): Input channel.
+        out_channel (int): Output channel.
+        stride (int): Stride size for the first convolutional layer. Default: 1.
+
+    Returns:
+        Tensor, output tensor.
+
+    Examples:
+        >>> ResidualBlock(3, 256, stride=2)
+    """
+    expansion = 4
+
+    def __init__(self,
+                 in_channel,
+                 out_channel,
+                 stride=1):
+        super(ResidualBlock, self).__init__()
+
+        channel = out_channel // self.expansion
+        self.conv1 = ConvBNReLU(in_channel, channel, kernel_size=1, stride=1)
+        self.conv2 = ConvBNReLU(channel, channel, kernel_size=3, stride=stride)
+        self.conv3 = nn.SequentialCell([Conv2dBatchNormQuant(channel, out_channel, fake=_fake, per_channel=_per_channel,
+                                                             symmetric=_symmetric,
+                                                             kernel_size=1, stride=1, pad_mode='same', padding=0),
+                                        FakeQuantWithMinMax(
+                                            ema=True, ema_decay=_ema_decay, symmetric=False)
+                                        ]) if _fake else Conv2dBatchNormQuant(channel, out_channel, fake=_fake,
+                                                                              per_channel=_per_channel,
+                                                                              symmetric=_symmetric,
+                                                                              kernel_size=1, stride=1,
+                                                                              pad_mode='same', padding=0)
+
+        self.down_sample = False
+
+        if stride != 1 or in_channel != out_channel:
+            self.down_sample = True
+        self.down_sample_layer = None
+
+        if self.down_sample:
+            self.down_sample_layer = nn.SequentialCell([Conv2dBatchNormQuant(in_channel, out_channel,
+                                                                             per_channel=_per_channel,
+                                                                             symmetric=_symmetric,
+                                                                             kernel_size=1, stride=stride,
+                                                                             pad_mode='same', padding=0),
+                                                        FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay,
+                                                                            symmetric=False)
+                                                        ]) if _fake else Conv2dBatchNormQuant(in_channel, out_channel,
+                                                                                              fake=_fake,
+                                                                                              per_channel=_per_channel,
+                                                                                              symmetric=_symmetric,
+                                                                                              kernel_size=1,
+                                                                                              stride=stride,
+                                                                                              pad_mode='same',
+                                                                                              padding=0)
+        self.add = nn.TensorAddQuant()
+        self.relu = P.ReLU()
+
+    def construct(self, x):
+        identity = x
+        out = self.conv1(x)
+        out = self.conv2(out)
+        out = self.conv3(out)
+
+        if self.down_sample:
+            identity = self.down_sample_layer(identity)
+
+        out = self.add(out, identity)
+        out = self.relu(out)
+
+        return out
+
+
+class ResNet(nn.Cell):
+    """
+    ResNet architecture.
+
+    Args:
+        block (Cell): Block for network.
+        layer_nums (list): Numbers of blocks in the different layers.
+        in_channels (list): Input channel in each layer.
+        out_channels (list): Output channel in each layer.
+        strides (list): Stride size in each layer.
+        num_classes (int): The number of classes that the training images belong to.
+    Returns:
+        Tensor, output tensor.
+
+    Examples:
+        >>> ResNet(ResidualBlock,
+        >>>        [3, 4, 6, 3],
+        >>>        [64, 256, 512, 1024],
+        >>>        [256, 512, 1024, 2048],
+        >>>        [1, 2, 2, 2],
+        >>>        10)
+    """
+
+    def __init__(self,
+                 block,
+                 layer_nums,
+                 in_channels,
+                 out_channels,
+                 strides,
+                 num_classes):
+        super(ResNet, self).__init__()
+
+        if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
+            raise ValueError(
+                "the length of layer_nums, in_channels and out_channels must be 4!")
+
+        self.conv1 = ConvBNReLU(3, 64, kernel_size=7, stride=2)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
+
+        self.layer1 = self._make_layer(block,
+                                       layer_nums[0],
+                                       in_channel=in_channels[0],
+                                       out_channel=out_channels[0],
+                                       stride=strides[0])
+        self.layer2 = self._make_layer(block,
+                                       layer_nums[1],
+                                       in_channel=in_channels[1],
+                                       out_channel=out_channels[1],
+                                       stride=strides[1])
+        self.layer3 = self._make_layer(block,
+                                       layer_nums[2],
+                                       in_channel=in_channels[2],
+                                       out_channel=out_channels[2],
+                                       stride=strides[2])
+        self.layer4 = self._make_layer(block,
+                                       layer_nums[3],
+                                       in_channel=in_channels[3],
+                                       out_channel=out_channels[3],
+                                       stride=strides[3])
+
+        self.mean = P.ReduceMean(keep_dims=True)
+        self.flatten = nn.Flatten()
+        self.end_point = nn.DenseQuant(out_channels[3], num_classes, has_bias=True, per_channel=_per_channel,
+                                       symmetric=_symmetric)
+        self.output_fake = nn.FakeQuantWithMinMax(
+            ema=True, ema_decay=_ema_decay)
+
+        # init weights
+        self._initialize_weights()
+
+    def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
+        """
+        Make stage network of ResNet.
+
+        Args:
+            block (Cell): Resnet block.
+            layer_num (int): Layer number.
+            in_channel (int): Input channel.
+            out_channel (int): Output channel.
+            stride (int): Stride size for the first convolutional layer.
+
+        Returns:
+            SequentialCell, the output layer.
+
+        Examples:
+            >>> _make_layer(ResidualBlock, 3, 128, 256, 2)
+        """
+        layers = []
+
+        resnet_block = block(in_channel, out_channel, stride=stride)
+        layers.append(resnet_block)
+
+        for _ in range(1, layer_num):
+            resnet_block = block(out_channel, out_channel, stride=1)
+            layers.append(resnet_block)
+
+        return nn.SequentialCell(layers)
+
+    def construct(self, x):
+        x = self.conv1(x)
+        c1 = self.maxpool(x)
+
+        c2 = self.layer1(c1)
+        c3 = self.layer2(c2)
+        c4 = self.layer3(c3)
+        c5 = self.layer4(c4)
+
+        out = self.mean(c5, (2, 3))
+        out = self.flatten(out)
+        out = self.end_point(out)
+        out = self.output_fake(out)
+        return out
+
+    def _initialize_weights(self):
+
+        self.init_parameters_data()
+        for _, m in self.cells_and_names():
+            # reseed before every cell so initialization is deterministic
+            np.random.seed(1)
+
+            if isinstance(m, nn.Conv2dBnFoldQuant):
+                m.weight.set_data(weight_init.initializer(weight_init.Normal(),
+                                                          m.weight.shape,
+                                                          m.weight.dtype))
+            elif isinstance(m, nn.DenseQuant):
+                m.weight.set_data(weight_init.initializer(weight_init.Normal(),
+                                                          m.weight.shape,
+                                                          m.weight.dtype))
+            elif isinstance(m, nn.Conv2dBnWithoutFoldQuant):
+                m.weight.set_data(weight_init.initializer(weight_init.Normal(),
+                                                          m.weight.shape,
+                                                          m.weight.dtype))
+
+
+def resnet50_quant(class_num=10):
+    """
+    Get ResNet50 neural network.
+
+    Args:
+        class_num (int): Class number.
+
+    Returns:
+        Cell, cell instance of ResNet50 neural network.
+
+    Examples:
+        >>> net = resnet50_quant(10)
+    """
+    return ResNet(ResidualBlock,
+                  [3, 4, 6, 3],
+                  [64, 256, 512, 1024],
+                  [256, 512, 1024, 2048],
+                  [1, 2, 2, 2],
+                  class_num)
+
+
+def resnet101_quant(class_num=1001):
+    """
+    Get ResNet101 neural network.
+
+    Args:
+        class_num (int): Class number.
+
+    Returns:
+        Cell, cell instance of ResNet101 neural network.
+
+    Examples:
+        >>> net = resnet101_quant(1001)
+    """
+    return ResNet(ResidualBlock,
+                  [3, 4, 23, 3],
+                  [64, 256, 512, 1024],
+                  [256, 512, 1024, 2048],
+                  [1, 2, 2, 2],
+                  class_num)
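A forward-shape smoke test for the manually quantized network (a sketch; it assumes a backend that supports the fake-quant ops, e.g. Ascend as in the test below):

```python
import numpy as np
from mindspore import Tensor, context
from resnet_quant_manual import resnet50_quant

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
net = resnet50_quant(class_num=10)
out = net(Tensor(np.random.randn(1, 3, 224, 224).astype(np.float32)))
print(out.shape)  # (1, 10); logits pass through the trailing FakeQuantWithMinMax
```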
diff --git a/tests/st/quantization/resnet50_quant/test_resnet50_quant.py b/tests/st/quantization/resnet50_quant/test_resnet50_quant.py
new file mode 100755
index 00000000000..3bac6a13d53
--- /dev/null
+++ b/tests/st/quantization/resnet50_quant/test_resnet50_quant.py
@@ -0,0 +1,131 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Train Resnet50_quant on Cifar10"""
+
+import pytest
+import numpy as np
+from easydict import EasyDict as ed
+
+from mindspore import context
+from mindspore import Tensor
+from mindspore.nn.optim.momentum import Momentum
+from mindspore.train.model import Model
+from mindspore.train.quant import quant
+from mindspore import set_seed
+
+from resnet_quant_manual import resnet50_quant
+from dataset import create_dataset
+from lr_generator import get_lr
+from utils import Monitor, CrossEntropy
+
+
+config_quant = ed({
+    "class_num": 10,
+    "batch_size": 128,
+    "step_threshold": 20,
+    "loss_scale": 1024,
+    "momentum": 0.9,
+    "weight_decay": 1e-4,
+    "epoch_size": 1,
+    "pretrained_epoch_size": 90,
+    "buffer_size": 1000,
+    "image_height": 224,
+    "image_width": 224,
+    "data_load_mode": "mindata",
+    "save_checkpoint": True,
+    "save_checkpoint_epochs": 1,
+    "keep_checkpoint_max": 50,
+    "save_checkpoint_path": "./",
+    "warmup_epochs": 0,
+    "lr_decay_mode": "cosine",
+    "use_label_smooth": True,
+    "label_smooth_factor": 0.1,
+    "lr_init": 0,
+    "lr_max": 0.005,
+})
+
+dataset_path = "/dataset/workspace/mindspore_dataset/cifar-10-batches-bin/"
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_resnet50_quant():
+    # the test_ prefix is required for pytest to collect this entry point
+    set_seed(1)
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    config = config_quant
+    print("training configure: {}".format(config))
+    epoch_size = config.epoch_size
+
+    # define network
+    net = resnet50_quant(class_num=config.class_num)
+    net.set_train(True)
+
+    # define loss
+    if not config.use_label_smooth:
+        config.label_smooth_factor = 0.0
+    loss = CrossEntropy(
+        smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
+    #loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
+
+    # define dataset
+    dataset = create_dataset(dataset_path=dataset_path,
+                             config=config,
+                             repeat_num=1,
+                             batch_size=config.batch_size)
+    step_size = dataset.get_dataset_size()
+
+    # convert fusion network to quantization aware network
+    net = quant.convert_quant_network(net,
+                                      bn_fold=True,
+                                      per_channel=[True, False],
+                                      symmetric=[True, False])
+
+    # get learning rate
+    lr = Tensor(get_lr(lr_init=config.lr_init,
+                       lr_end=0.0,
+                       lr_max=config.lr_max,
+                       warmup_epochs=config.warmup_epochs,
+                       total_epochs=config.epoch_size,
+                       steps_per_epoch=step_size,
+                       lr_decay_mode='cosine'))
+
+    # define optimizer
+    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
+                   config.weight_decay, config.loss_scale)
+
+    # define model
+    #model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'})
+    model = Model(net, loss_fn=loss, optimizer=opt)
+
+    print("============== Starting Training ==============")
+    monitor = Monitor(lr_init=lr.asnumpy(),
+                      step_threshold=config.step_threshold)
+
+    callbacks = [monitor]
+    model.train(epoch_size, dataset, callbacks=callbacks,
+                dataset_sink_mode=False)
+    print("============== End Training ==============")
+
+    expect_avg_step_loss = 2.40
+    avg_step_loss = np.mean(np.array(monitor.losses))
+
+    print("average step loss:{}".format(avg_step_loss))
+    assert avg_step_loss < expect_avg_step_loss
+
+
+if __name__ == '__main__':
+    test_resnet50_quant()
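If the commented-out loss-scale path is ever revived, it would look roughly like the sketch below (assuming FixedLossScaleManager from mindspore.train.loss_scale_manager; note the optimizer above is already constructed with the matching config.loss_scale, which drop_overflow_update=False requires):

```python
from mindspore.train.loss_scale_manager import FixedLossScaleManager

# mirrors the commented-out lines inside the test body
loss_scale = FixedLossScaleManager(config_quant.loss_scale, drop_overflow_update=False)
model = Model(net, loss_fn=loss, optimizer=opt,
              loss_scale_manager=loss_scale, metrics={'acc'})
```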
diff --git a/tests/st/quantization/resnet50_quant/utils.py b/tests/st/quantization/resnet50_quant/utils.py
new file mode 100644
index 00000000000..5711e126c71
--- /dev/null
+++ b/tests/st/quantization/resnet50_quant/utils.py
@@ -0,0 +1,105 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Resnet50 utils"""
+
+import time
+import numpy as np
+
+from mindspore.train.callback import Callback
+from mindspore import Tensor
+from mindspore import nn
+from mindspore.nn.loss.loss import _Loss
+from mindspore.ops import operations as P
+from mindspore.ops import functional as F
+from mindspore.common import dtype as mstype
+
+
+class Monitor(Callback):
+    """
+    Monitor loss and time.
+
+    Args:
+        lr_init (numpy.ndarray): the learning rate array used for training.
+        step_threshold (int): step number at which training is stopped. Default: 10.
+
+    Returns:
+        None.
+
+    Examples:
+        >>> Monitor(lr_init=Tensor([0.05] * 100).asnumpy(), step_threshold=10)
+    """
+
+    def __init__(self, lr_init=None, step_threshold=10):
+        super(Monitor, self).__init__()
+        self.lr_init = lr_init
+        self.lr_init_len = len(lr_init)
+        self.step_threshold = step_threshold
+
+    def epoch_begin(self, run_context):
+        self.losses = []
+        self.epoch_time = time.time()
+
+    def epoch_end(self, run_context):
+        cb_params = run_context.original_args()
+
+        epoch_mseconds = (time.time() - self.epoch_time) * 1000
+        per_step_mseconds = epoch_mseconds / cb_params.batch_num
+        print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:8.6f}".format(epoch_mseconds,
+                                                                                      per_step_mseconds,
+                                                                                      np.mean(self.losses)))
+        self.epoch_mseconds = epoch_mseconds
+
+    def step_begin(self, run_context):
+        self.step_time = time.time()
+
+    def step_end(self, run_context):
+        cb_params = run_context.original_args()
+        step_mseconds = (time.time() - self.step_time) * 1000
+        step_loss = cb_params.net_outputs
+
+        if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):
+            step_loss = step_loss[0]
+        if isinstance(step_loss, Tensor):
+            step_loss = np.mean(step_loss.asnumpy())
+
+        self.losses.append(step_loss)
+        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num
+
+        print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:8.6f}/{:8.6f}], time:[{:5.3f}], lr:[{:5.5f}]".format(
+            cb_params.cur_epoch_num, cb_params.epoch_num, cur_step_in_epoch +
+            1, cb_params.batch_num, step_loss,
+            np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1]))
+
+        if cb_params.cur_step_num == self.step_threshold:
+            run_context.request_stop()
+
+
+class CrossEntropy(_Loss):
+    """Cross-entropy loss with label smoothing, built on SoftmaxCrossEntropyWithLogits."""
+
+    def __init__(self, smooth_factor=0, num_classes=1001):
+        super(CrossEntropy, self).__init__()
+        self.onehot = P.OneHot()
+        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
+        self.off_value = Tensor(1.0 * smooth_factor /
+                                (num_classes - 1), mstype.float32)
+        self.ce = nn.SoftmaxCrossEntropyWithLogits()
+        self.mean = P.ReduceMean(False)
+
+    def construct(self, logit, label):
+        one_hot_label = self.onehot(label, F.shape(
+            logit)[1], self.on_value, self.off_value)
+        loss = self.ce(logit, one_hot_label)
+        loss = self.mean(loss, 0)
+        return loss
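A quick sanity check of this loss (a sketch): with all-zero logits the softmax is uniform, so any valid target distribution yields log(num_classes). The expect_avg_step_loss thresholds in both tests (2.32 and 2.40) sit just above log(10) ≈ 2.303, i.e. they check that a few training steps keep the loss near the uniform-prediction baseline rather than diverging.

```python
import numpy as np
from mindspore import Tensor
import mindspore.common.dtype as mstype
from utils import CrossEntropy

loss_fn = CrossEntropy(smooth_factor=0.1, num_classes=10)
logits = Tensor(np.zeros((4, 10), np.float32))   # uniform predictions
labels = Tensor(np.array([0, 1, 2, 3]), mstype.int32)
print(loss_fn(logits, labels))                   # ≈ log(10) ≈ 2.3026
```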