[Quant][lenet]eval should set bn_fold as true

This commit is contained in:
chenfei 2020-08-20 11:35:31 +08:00
parent a27e6f5769
commit 2a5d90dc15
15 changed files with 55 additions and 350 deletions

View File

@ -250,3 +250,35 @@ def without_fold_batchnorm(weight, cell_quant):
weight = weight * _gamma / _sigma
bias = beta - gamma * mean / sigma
return weight, bias
def load_nonquant_param_into_quant_net(quant_model, params_dict):
"""
load fp32 model parameters to quantization model.
Args:
quant_model: quantization model
params_dict: f32 param
Returns:
None
"""
iterable_dict = {
'weight': iter([item for item in params_dict.items() if item[0].endswith('weight')]),
'bias': iter([item for item in params_dict.items() if item[0].endswith('bias')]),
'gamma': iter([item for item in params_dict.items() if item[0].endswith('gamma')]),
'beta': iter([item for item in params_dict.items() if item[0].endswith('beta')]),
'moving_mean': iter([item for item in params_dict.items() if item[0].endswith('moving_mean')]),
'moving_variance': iter(
[item for item in params_dict.items() if item[0].endswith('moving_variance')]),
'minq': iter([item for item in params_dict.items() if item[0].endswith('minq')]),
'maxq': iter([item for item in params_dict.items() if item[0].endswith('maxq')])
}
for name, param in quant_model.parameters_and_names():
key_name = name.split(".")[-1]
if key_name not in iterable_dict.keys():
raise ValueError(f"Can't find match parameter in ckpt,param name = {name}")
value_param = next(iterable_dict[key_name], None)
if value_param is not None:
param.set_parameter_data(value_param[1].data)
print(f'init model param {name} with checkpoint param {value_param[0]}')

View File

@ -308,6 +308,7 @@ def load_param_into_net(net, parameter_dict):
logger.debug("%s", param_name)
logger.info("Load parameter into net finish, {} parameters has not been loaded.".format(len(param_not_load)))
return param_not_load
def _load_dismatch_prefix_params(net, parameter_dict, param_not_load):

View File

@ -93,65 +93,6 @@ Get the MNIST from scratch dataset.
ds_train = create_dataset(os.path.join(args.data_path, "train"),
cfg.batch_size, cfg.epoch_size)
step_size = ds_train.get_dataset_size()
```
### Train model
Load the Lenet fusion network, training network using loss `nn.SoftmaxCrossEntropyWithLogits` with optimization `nn.Momentum`.
```Python
# Define the network
network = LeNet5Fusion(cfg.num_classes)
# Define the loss
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
# Define optimization
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
# Define model using loss and optimization.
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
```
Now we can start training.
```Python
model.train(cfg['epoch_size'], ds_train,
callbacks=[time_cb, ckpoint_cb, LossMonitor()],
dataset_sink_mode=args.dataset_sink_mode)
```
After all the following we will get the loss value of each step as following:
```bash
>>> Epoch: [ 1/ 10] step: [ 1/ 900], loss: [2.3040/2.5234], time: [1.300234]
>>> ...
>>> Epoch: [ 9/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234]
>>> Epoch: [ 9/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234]
>>> Epoch: [ 9/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234]
```
Also, you can just run this command instead.
```python
python train.py --data_path MNIST_Data --device_target Ascend
```
### Evaluate fusion model
After training epoch stop. We can get the fusion model checkpoint file like `checkpoint_lenet.ckpt`. Meanwhile, we can evaluate this fusion model.
```python
python eval.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt
```
The top1 accuracy would display on shell.
```bash
>>> Accuracy: 98.53.
```
## Train quantization aware model

View File

@ -1,65 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
######################## eval lenet example ########################
eval lenet according to model file:
python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt
"""
import os
import argparse
import mindspore.nn as nn
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from src.dataset import create_dataset
from src.config import mnist_cfg as cfg
from src.lenet_fusion import LeNet5 as LeNet5Fusion
parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend",
choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./MNIST_Data",
help='path where the dataset is saved')
parser.add_argument('--ckpt_path', type=str, default="",
help='if mode is test, must provide path where the trained ckpt file')
parser.add_argument('--dataset_sink_mode', type=bool, default=True,
help='dataset_sink_mode is False or True')
args = parser.parse_args()
if __name__ == "__main__":
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
ds_eval = create_dataset(os.path.join(args.data_path, "test"), cfg.batch_size, 1)
step_size = ds_eval.get_dataset_size()
# define fusion network
network = LeNet5Fusion(cfg.num_classes)
# define loss
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
# define network optimization
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
# call back and monitor
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
# load check point into network
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict)
print("============== Starting Testing ==============")
acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode)
print("============== {} ==============".format(acc))

View File

@ -63,7 +63,9 @@ if __name__ == "__main__":
# load quantization aware network checkpoint
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict)
not_load_param = load_param_into_net(network, param_dict)
if not_load_param:
raise ValueError("Load param into net fail!")
print("============== Starting Testing ==============")
acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode)

View File

@ -1,64 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LeNet."""
import mindspore.nn as nn
class LeNet5(nn.Cell):
"""
Lenet network
Args:
num_class (int): Num classes. Default: 10.
Returns:
Tensor, output tensor
Examples:
>>> LeNet(num_class=10)
"""
def __init__(self, num_class=10, channel=1):
super(LeNet5, self).__init__()
self.num_class = num_class
self.conv1 = nn.Conv2d(channel, 6, 5, pad_mode='valid')
self.bn1 = nn.BatchNorm2d(6)
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.bn2 = nn.BatchNorm2d(16)
self.fc1 = nn.Dense(16 * 5 * 5, 120)
self.fc2 = nn.Dense(120, 84)
self.fc3 = nn.Dense(84, self.num_class)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
def construct(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x

View File

@ -36,8 +36,8 @@ class LeNet5(nn.Cell):
self.num_class = num_class
# change `nn.Conv2d` to `nn.Conv2dBnAct`
self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', has_bn=True, activation='relu')
self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', has_bn=True, activation='relu')
self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', activation='relu')
self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', activation='relu')
# change `nn.Dense` to `nn.DenseBnAct`
self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
self.fc2 = nn.DenseBnAct(120, 84, activation='relu')

View File

@ -1,68 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
######################## train lenet example ########################
train lenet and get network model files(.ckpt) :
python train.py --data_path /YourDataPath
"""
import os
import argparse
import mindspore.nn as nn
from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from src.dataset import create_dataset
from src.config import mnist_cfg as cfg
from src.lenet_fusion import LeNet5 as LeNet5Fusion
from src.loss_monitor import LossMonitor
parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend",
choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./MNIST_Data",
help='path where the dataset is saved')
parser.add_argument('--ckpt_path', type=str, default="",
help='if mode is test, must provide path where the trained ckpt file')
parser.add_argument('--dataset_sink_mode', type=bool, default=True,
help='dataset_sink_mode is False or True')
args = parser.parse_args()
if __name__ == "__main__":
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, 1)
step_size = ds_train.get_dataset_size()
# define fusion network
network = LeNet5Fusion(cfg.num_classes)
# define network loss
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
# define network optimization
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
# call back and monitor
config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)
# define model
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
print("============== Starting Training ==============")
model.train(cfg['epoch_size'], ds_train, callbacks=[ckpt_callback, LossMonitor()],
dataset_sink_mode=args.dataset_sink_mode)
print("============== End Training ==============")

View File

@ -22,11 +22,12 @@ import os
import argparse
import mindspore.nn as nn
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.serialization import load_checkpoint
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore.train.quant import quant
from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net
from src.dataset import create_dataset
from src.config import mnist_cfg as cfg
from src.lenet_fusion import LeNet5 as LeNet5Fusion
@ -54,10 +55,11 @@ if __name__ == "__main__":
# load quantization aware network checkpoint
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict)
load_nonquant_param_into_quant_net(network, param_dict)
# convert fusion network to quantization aware network
network = quant.convert_quant_network(network, quant_delay=900, per_channel=[True, False], symmetric=[False, False])
network = quant.convert_quant_network(network, quant_delay=900, bn_fold=False, per_channel=[True, False],
symmetric=[False, False])
# define network loss
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")

View File

@ -68,7 +68,9 @@ if __name__ == '__main__':
# load checkpoint
if args_opt.checkpoint_path:
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(network, param_dict)
not_load_param = load_param_into_net(network, param_dict)
if not_load_param:
raise ValueError("Load param into net fail!")
network.set_train(False)
# define model

View File

@ -25,39 +25,6 @@ from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
def _load_param_into_net(model, params_dict):
"""
load fp32 model parameters to quantization model.
Args:
model: quantization model
params_dict: f32 param
Returns:
None
"""
iterable_dict = {
'weight': iter([item for item in params_dict.items() if item[0].endswith('weight')]),
'bias': iter([item for item in params_dict.items() if item[0].endswith('bias')]),
'gamma': iter([item for item in params_dict.items() if item[0].endswith('gamma')]),
'beta': iter([item for item in params_dict.items() if item[0].endswith('beta')]),
'moving_mean': iter([item for item in params_dict.items() if item[0].endswith('moving_mean')]),
'moving_variance': iter(
[item for item in params_dict.items() if item[0].endswith('moving_variance')]),
'minq': iter([item for item in params_dict.items() if item[0].endswith('minq')]),
'maxq': iter([item for item in params_dict.items() if item[0].endswith('maxq')])
}
for name, param in model.parameters_and_names():
key_name = name.split(".")[-1]
if key_name not in iterable_dict.keys():
raise ValueError(f"Can't find match parameter in ckpt,param name = {name}")
value_param = next(iterable_dict[key_name], None)
if value_param is not None:
param.set_parameter_data(value_param[1].data)
print(f'init model param {name} with checkpoint param {value_param[0]}')
class Monitor(Callback):
"""
Monitor loss and time.

View File

@ -28,6 +28,7 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.serialization import load_checkpoint
from mindspore.communication.management import init, get_group_size, get_rank
from mindspore.train.quant import quant
from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net
import mindspore.dataset.engine as de
from src.dataset import create_dataset
@ -35,7 +36,6 @@ from src.lr_generator import get_lr
from src.utils import Monitor, CrossEntropyWithLabelSmooth
from src.config import config_ascend_quant, config_gpu_quant
from src.mobilenetV2 import mobilenetV2
from src.utils import _load_param_into_net
random.seed(1)
np.random.seed(1)
@ -101,7 +101,7 @@ def train_on_ascend():
# load pre trained ckpt
if args_opt.pre_trained:
param_dict = load_checkpoint(args_opt.pre_trained)
_load_param_into_net(network, param_dict)
load_nonquant_param_into_quant_net(network, param_dict)
# convert fusion network to quantization aware network
network = quant.convert_quant_network(network,
bn_fold=True,
@ -163,7 +163,7 @@ def train_on_gpu():
# resume
if args_opt.pre_trained:
param_dict = load_checkpoint(args_opt.pre_trained)
_load_param_into_net(network, param_dict)
load_nonquant_param_into_quant_net(network, param_dict)
# convert fusion network to quantization aware network
network = quant.convert_quant_network(network,

View File

@ -20,12 +20,11 @@ import argparse
from src.config import quant_set, config_quant, config_noquant
from src.dataset import create_dataset
from src.crossentropy import CrossEntropy
from src.utils import _load_param_into_net
from models.resnet_quant import resnet50_quant
from mindspore import context
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.quant import quant
parser = argparse.ArgumentParser(description='Image classification')
@ -66,7 +65,9 @@ if __name__ == '__main__':
# load checkpoint
if args_opt.checkpoint_path:
param_dict = load_checkpoint(args_opt.checkpoint_path)
_load_param_into_net(net, param_dict)
not_load_param = load_param_into_net(net, param_dict)
if not_load_param:
raise ValueError("Load param into net fail!")
net.set_train(False)
# define model

View File

@ -1,46 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""utils script"""
def _load_param_into_net(model, params_dict):
"""
load fp32 model parameters to quantization model.
Args:
model: quantization model
params_dict: f32 param
Returns:
None
"""
iterable_dict = {
'weight': iter([item for item in params_dict.items() if item[0].endswith('weight')]),
'bias': iter([item for item in params_dict.items() if item[0].endswith('bias')]),
'gamma': iter([item for item in params_dict.items() if item[0].endswith('gamma')]),
'beta': iter([item for item in params_dict.items() if item[0].endswith('beta')]),
'moving_mean': iter([item for item in params_dict.items() if item[0].endswith('moving_mean')]),
'moving_variance': iter(
[item for item in params_dict.items() if item[0].endswith('moving_variance')]),
'minq': iter([item for item in params_dict.items() if item[0].endswith('minq')]),
'maxq': iter([item for item in params_dict.items() if item[0].endswith('maxq')])
}
for name, param in model.parameters_and_names():
key_name = name.split(".")[-1]
if key_name not in iterable_dict.keys():
continue
value_param = next(iterable_dict[key_name], None)
if value_param is not None:
param.set_parameter_data(value_param[1].data)
print(f'init model param {name} with checkpoint param {value_param[0]}')

View File

@ -26,6 +26,7 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMoni
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint
from mindspore.train.quant import quant
from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net
from mindspore.communication.management import init
import mindspore.nn as nn
import mindspore.common.initializer as weight_init
@ -35,7 +36,6 @@ from src.dataset import create_dataset
from src.lr_generator import get_lr
from src.config import config_quant
from src.crossentropy import CrossEntropy
from src.utils import _load_param_into_net
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
@ -85,7 +85,7 @@ if __name__ == '__main__':
# weight init and load checkpoint file
if args_opt.pre_trained:
param_dict = load_checkpoint(args_opt.pre_trained)
_load_param_into_net(net, param_dict)
load_nonquant_param_into_quant_net(net, param_dict)
epoch_size = config.epoch_size - config.pretrained_epoch_size
else:
for _, cell in net.cells_and_names():