From a71868f1e5877e19a39f73bb5a7e8727614d1942 Mon Sep 17 00:00:00 2001
From: chenfei <chefei52@huawei.com>
Date: Thu, 20 Aug 2020 19:36:52 +0800
Subject: [PATCH] add ci for quant

Add a CI smoke test for LeNet quantization aware training and evaluation
under tests/st/quantization/lenet_quant, simplify the resnet50_quant eval
and dataset scripts to always use the quantization config, and drop an
unused variable from lenet_quant eval_quant.py.
---
 .../official/cv/lenet_quant/eval_quant.py     |   1 -
 model_zoo/official/cv/resnet50_quant/eval.py  |  15 +-
 .../official/cv/resnet50_quant/src/dataset.py |   4 +-
 tests/st/quantization/lenet_quant/config.py   |  44 ++++++
 tests/st/quantization/lenet_quant/dataset.py  |  60 ++++++++
 tests/st/quantization/lenet_quant/lenet.py    |  79 ++++++++++
 .../quantization/lenet_quant/lenet_fusion.py  |  58 ++++++++
 .../lenet_quant/test_lenet_quant.py           | 136 ++++++++++++++++++
 8 files changed, 386 insertions(+), 11 deletions(-)
 create mode 100644 tests/st/quantization/lenet_quant/config.py
 create mode 100644 tests/st/quantization/lenet_quant/dataset.py
 create mode 100644 tests/st/quantization/lenet_quant/lenet.py
 create mode 100644 tests/st/quantization/lenet_quant/lenet_fusion.py
 create mode 100644 tests/st/quantization/lenet_quant/test_lenet_quant.py

diff --git a/model_zoo/official/cv/lenet_quant/eval_quant.py b/model_zoo/official/cv/lenet_quant/eval_quant.py
index 4849d01daf9..f545a8a23a8 100644
--- a/model_zoo/official/cv/lenet_quant/eval_quant.py
+++ b/model_zoo/official/cv/lenet_quant/eval_quant.py
@@ -45,7 +45,6 @@ args = parser.parse_args()
 if __name__ == "__main__":
     context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
     ds_eval = create_dataset(os.path.join(args.data_path, "test"), cfg.batch_size, 1)
-    step_size = ds_eval.get_dataset_size()
 
     # define fusion network
     network = LeNet5Fusion(cfg.num_classes)
diff --git a/model_zoo/official/cv/resnet50_quant/eval.py b/model_zoo/official/cv/resnet50_quant/eval.py
index 0395e38b601..9eb3ce3520a 100755
--- a/model_zoo/official/cv/resnet50_quant/eval.py
+++ b/model_zoo/official/cv/resnet50_quant/eval.py
@@ -17,7 +17,7 @@
 import os
 import argparse
 
-from src.config import quant_set, config_quant, config_noquant
+from src.config import config_quant
 from src.dataset import create_dataset
 from src.crossentropy import CrossEntropy
 from models.resnet_quant import resnet50_quant
@@ -34,7 +34,7 @@ parser.add_argument('--device_target', type=str, default='Ascend', help='Device
 args_opt = parser.parse_args()
 
 context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, save_graphs=False)
-config = config_quant if quant_set.quantization_aware else config_noquant
+config = config_quant
 
 if args_opt.device_target == "Ascend":
     device_id = int(os.getenv('DEVICE_ID'))
@@ -43,12 +43,11 @@ if args_opt.device_target == "Ascend":
 if __name__ == '__main__':
     # define fusion network
     net = resnet50_quant(class_num=config.class_num)
-    if quant_set.quantization_aware:
-        # convert fusion network to quantization aware network
-        net = quant.convert_quant_network(net,
-                                          bn_fold=True,
-                                          per_channel=[True, False],
-                                          symmetric=[True, False])
+    # convert fusion network to quantization aware network
+    net = quant.convert_quant_network(net,
+                                      bn_fold=True,
+                                      per_channel=[True, False],
+                                      symmetric=[True, False])
     # define network loss
     if not config.use_label_smooth:
         config.label_smooth_factor = 0.0
diff --git a/model_zoo/official/cv/resnet50_quant/src/dataset.py b/model_zoo/official/cv/resnet50_quant/src/dataset.py
index 73c07800900..3c35adeaaa0 100755
--- a/model_zoo/official/cv/resnet50_quant/src/dataset.py
+++ b/model_zoo/official/cv/resnet50_quant/src/dataset.py
@@ -23,9 +23,9 @@ import mindspore.dataset.transforms.vision.c_transforms as C
 import mindspore.dataset.transforms.c_transforms as C2
 import mindspore.dataset.transforms.vision.py_transforms as P
 from mindspore.communication.management import init, get_rank, get_group_size
-from src.config import quant_set, config_quant, config_noquant
+from src.config import config_quant
 
-config = config_quant if quant_set.quantization_aware else config_noquant
+config = config_quant
 
 
 def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
diff --git a/tests/st/quantization/lenet_quant/config.py b/tests/st/quantization/lenet_quant/config.py
new file mode 100644
index 00000000000..7c4f5a54b7e
--- /dev/null
+++ b/tests/st/quantization/lenet_quant/config.py
@@ -0,0 +1,44 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Network config settings used by test_lenet_quant.py.
+"""
+
+from easydict import EasyDict as edict
+
+nonquant_cfg = edict({
+    'num_classes': 10,
+    'lr': 0.01,
+    'momentum': 0.9,
+    'epoch_size': 10,
+    'batch_size': 32,
+    'buffer_size': 1000,
+    'image_height': 32,
+    'image_width': 32,
+    'save_checkpoint_steps': 1875,
+    'keep_checkpoint_max': 10,
+})
+
+quant_cfg = edict({
+    'num_classes': 10,
+    'lr': 0.01,
+    'momentum': 0.9,
+    'epoch_size': 10,
+    'batch_size': 64,
+    'buffer_size': 1000,
+    'image_height': 32,
+    'image_width': 32,
+    'keep_checkpoint_max': 10,
+})
diff --git a/tests/st/quantization/lenet_quant/dataset.py b/tests/st/quantization/lenet_quant/dataset.py
new file mode 100644
index 00000000000..cef69734839
--- /dev/null
+++ b/tests/st/quantization/lenet_quant/dataset.py
@@ -0,0 +1,60 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Produce the dataset
+"""
+
+import mindspore.dataset as ds
+import mindspore.dataset.transforms.vision.c_transforms as CV
+import mindspore.dataset.transforms.c_transforms as C
+from mindspore.dataset.transforms.vision import Inter
+from mindspore.common import dtype as mstype
+
+
+def create_dataset(data_path, batch_size=32, repeat_size=1,
+                   num_parallel_workers=1):
+    """
+    Create the MNIST dataset for training or evaluation.
+    """
+    # define dataset
+    mnist_ds = ds.MnistDataset(data_path)
+
+    resize_height, resize_width = 32, 32
+    rescale = 1.0 / 255.0
+    shift = 0.0
+    rescale_nml = 1 / 0.3081
+    shift_nml = -1 * 0.1307 / 0.3081
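+    # 0.1307 / 0.3081 are the commonly used MNIST mean / std, so the second rescale normalizes the pixels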
+
+    # define map operations
+    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)  # Bilinear mode
+    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
+    rescale_op = CV.Rescale(rescale, shift)
+    hwc2chw_op = CV.HWC2CHW()
+    type_cast_op = C.TypeCast(mstype.int32)
+
+    # apply map operations on images
+    mnist_ds = mnist_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=num_parallel_workers)
+    mnist_ds = mnist_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=num_parallel_workers)
+    mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=num_parallel_workers)
+    mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)
+    mnist_ds = mnist_ds.map(input_columns="image", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)
+
+    # apply DatasetOps
+    buffer_size = 10000
+    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)  # 10000 as in LeNet train script
+    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
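+    # drop_remainder=True keeps every batch at exactly batch_size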
+    mnist_ds = mnist_ds.repeat(repeat_size)
+
+    return mnist_ds
diff --git a/tests/st/quantization/lenet_quant/lenet.py b/tests/st/quantization/lenet_quant/lenet.py
new file mode 100644
index 00000000000..42444100073
--- /dev/null
+++ b/tests/st/quantization/lenet_quant/lenet.py
@@ -0,0 +1,79 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""LeNet."""
+import mindspore.nn as nn
+from mindspore.common.initializer import TruncatedNormal
+
+
+def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
+    """weight initial for conv layer"""
+    weight = weight_variable()
+    return nn.Conv2d(in_channels, out_channels,
+                     kernel_size=kernel_size, stride=stride, padding=padding,
+                     weight_init=weight, has_bias=False, pad_mode="valid")
+
+
+def fc_with_initialize(input_channels, out_channels):
+    """weight initial for fc layer"""
+    weight = weight_variable()
+    bias = weight_variable()
+    return nn.Dense(input_channels, out_channels, weight, bias)
+
+
+def weight_variable():
+    """weight initial"""
+    return TruncatedNormal(0.02)
+
+
+class LeNet5(nn.Cell):
+    """
+    LeNet5 network.
+
+    Args:
+        num_class (int): Number of classes. Default: 10.
+
+    Returns:
+        Tensor, output tensor.
+
+    Examples:
+        >>> LeNet5(num_class=10)
+
+    """
+
+    def __init__(self, num_class=10, channel=1):
+        super(LeNet5, self).__init__()
+        self.num_class = num_class
+        self.conv1 = conv(channel, 6, 5)
+        self.conv2 = conv(6, 16, 5)
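+        # two conv(5x5, valid) + max-pool(2) stages shrink the 32x32 input to 16 maps of 5x5, hence 16 * 5 * 5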
+        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
+        self.fc2 = fc_with_initialize(120, 84)
+        self.fc3 = fc_with_initialize(84, self.num_class)
+        self.relu = nn.ReLU()
+        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.flatten = nn.Flatten()
+
+    def construct(self, x):
+        x = self.conv1(x)
+        x = self.relu(x)
+        x = self.max_pool2d(x)
+        x = self.conv2(x)
+        x = self.relu(x)
+        x = self.max_pool2d(x)
+        x = self.flatten(x)
+        x = self.fc1(x)
+        x = self.relu(x)
+        x = self.fc2(x)
+        x = self.relu(x)
+        x = self.fc3(x)
+        return x
diff --git a/tests/st/quantization/lenet_quant/lenet_fusion.py b/tests/st/quantization/lenet_quant/lenet_fusion.py
new file mode 100644
index 00000000000..88b35935027
--- /dev/null
+++ b/tests/st/quantization/lenet_quant/lenet_fusion.py
@@ -0,0 +1,58 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""LeNet."""
+import mindspore.nn as nn
+
+
+class LeNet5(nn.Cell):
+    """
+    LeNet5 fusion network for quantization aware training.
+
+    Args:
+        num_class (int): Number of classes. Default: 10.
+
+    Returns:
+        Tensor, output tensor.
+
+    Examples:
+        >>> LeNet5(num_class=10)
+
+    """
+
+    def __init__(self, num_class=10, channel=1):
+        super(LeNet5, self).__init__()
+        self.type = "fusion"
+        self.num_class = num_class
+
+        # change `nn.Conv2d` to `nn.Conv2dBnAct`
+        self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', activation='relu')
+        self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', activation='relu')
+        # change `nn.Dense` to `nn.DenseBnAct`
+        self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
+        self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
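+        # the last layer stays linear: the SoftmaxCrossEntropyWithLogits loss used in the test expects raw logits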
+        self.fc3 = nn.DenseBnAct(84, self.num_class)
+
+        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.flatten = nn.Flatten()
+
+    def construct(self, x):
+        x = self.conv1(x)
+        x = self.max_pool2d(x)
+        x = self.conv2(x)
+        x = self.max_pool2d(x)
+        x = self.flatten(x)
+        x = self.fc1(x)
+        x = self.fc2(x)
+        x = self.fc3(x)
+        return x
diff --git a/tests/st/quantization/lenet_quant/test_lenet_quant.py b/tests/st/quantization/lenet_quant/test_lenet_quant.py
new file mode 100644
index 00000000000..361aa1abf62
--- /dev/null
+++ b/tests/st/quantization/lenet_quant/test_lenet_quant.py
@@ -0,0 +1,136 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Train and evaluate the LeNet quantization aware network.
+"""
+
+import os
+import pytest
+from mindspore import context
+import mindspore.nn as nn
+from mindspore.nn.metrics import Accuracy
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore.train import Model
+from mindspore.train.quant import quant
+from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net
+from dataset import create_dataset
+from config import nonquant_cfg, quant_cfg
+from lenet import LeNet5
+from lenet_fusion import LeNet5 as LeNet5Fusion
+
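+# single-card GPU run against the MNIST copy kept on the CI machine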
+device_target = 'GPU'
+data_path = "/home/workspace/mindspore_dataset/mnist"
+
+
+def train_lenet():
+    context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
+    cfg = nonquant_cfg
+    ds_train = create_dataset(os.path.join(data_path, "train"),
+                              cfg.batch_size)
+
+    network = LeNet5(cfg.num_classes)
+    net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
+    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
+    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
+    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
+                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
+    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
+    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
+
+    print("============== Starting Training Lenet==============")
+    model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()],
+                dataset_sink_mode=True)
+
+
+def train_lenet_quant():
+    context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
+    cfg = quant_cfg
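+    # checkpoint written by train_lenet(): 10 epochs x 1875 steps (60000 MNIST images / batch size 32)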
+    ckpt_path = './checkpoint_lenet-10_1875.ckpt'
+    ds_train = create_dataset(os.path.join(data_path, "train"), cfg.batch_size, 1)
+    step_size = ds_train.get_dataset_size()
+
+    # define fusion network
+    network = LeNet5Fusion(cfg.num_classes)
+
+    # load the checkpoint trained without quantization and map it into the fusion network
+    param_dict = load_checkpoint(ckpt_path)
+    load_nonquant_param_into_quant_net(network, param_dict)
+
+    # convert fusion network to quantization aware network
+    network = quant.convert_quant_network(network, quant_delay=900, bn_fold=False, per_channel=[True, False],
+                                          symmetric=[False, False])
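+    # quant_delay=900: weights and activations are only fake-quantized after 900 training steps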
+
+    # define network loss
+    net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
+    # define network optimization
+    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
+
+    # call back and monitor
+    config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
+                                   keep_checkpoint_max=cfg.keep_checkpoint_max)
+    ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)
+
+    # define model
+    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
+
+    print("============== Starting Training ==============")
+    model.train(cfg['epoch_size'], ds_train, callbacks=[ckpt_callback, LossMonitor()],
+                dataset_sink_mode=True)
+    print("============== End Training ==============")
+
+
+def eval_quant():
+    context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
+    cfg = quant_cfg
+    ds_eval = create_dataset(os.path.join(data_path, "test"), cfg.batch_size, 1)
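+    # checkpoint written by train_lenet_quant(): 10 epochs x 937 steps (batch size 64, drop_remainder)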
+    ckpt_path = './checkpoint_lenet_1-10_937.ckpt'
+    # define fusion network
+    network = LeNet5Fusion(cfg.num_classes)
+    # convert fusion network to quantization aware network
+    network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000,
+                                          per_channel=[True, False])
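+    # quant_delay=0: fake quantization is active from the first step during evaluation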
+
+    # define loss
+    net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
+    # define network optimization
+    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
+
+    # define model
+    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
+
+    # load quantization aware network checkpoint
+    param_dict = load_checkpoint(ckpt_path)
+    not_load_param = load_param_into_net(network, param_dict)
+    if not_load_param:
+        raise ValueError("Failed to load parameters into the network!")
+
+    print("============== Starting Testing ==============")
+    acc = model.eval(ds_eval, dataset_sink_mode=True)
+    print("============== {} ==============".format(acc))
+    assert acc['Accuracy'] > 0.98
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_lenet_quant():
+    train_lenet()
+    train_lenet_quant()
+    eval_quant()
+
+
+if __name__ == "__main__":
+    train_lenet_quant()