forked from mindspore-Ecosystem/mindspore
!6639 add mobilenetv2_quant and resnet50_quant st
Merge pull request !6639 from hwjiaorui/master
Commit 3e885c0bc1
@ -0,0 +1,67 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Create the train dataset."""

from functools import partial

import mindspore.dataset as ds
import mindspore.common.dtype as mstype
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2


def create_dataset(dataset_path, config, repeat_num=1, batch_size=32):
    """
    Create a train dataset.

    Args:
        dataset_path (string): the path of the dataset.
        config (EasyDict): the basic config for training.
        repeat_num (int): the repeat times of the dataset. Default: 1.
        batch_size (int): the batch size of the dataset. Default: 32.

    Returns:
        Dataset, the processed CIFAR-10 train dataset.
    """
    load_func = partial(ds.Cifar10Dataset, dataset_path)
    cifar_ds = load_func(num_parallel_workers=8, shuffle=False)

    resize_height = config.image_height
    resize_width = config.image_width
    rescale = 1.0 / 255.0
    shift = 0.0

    # define map operations; Resize uses BILINEAR interpolation by default
    resize_op = C.Resize((resize_height, resize_width))
    rescale_op = C.Rescale(rescale, shift)
    normalize_op = C.Normalize(
        (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    changeswap_op = C.HWC2CHW()
    type_cast_op = C2.TypeCast(mstype.int32)

    c_trans = [resize_op, rescale_op, normalize_op, changeswap_op]

    # apply map operations on images
    cifar_ds = cifar_ds.map(input_columns="label", operations=type_cast_op)
    cifar_ds = cifar_ds.map(input_columns="image", operations=c_trans)

    # apply batch operations
    cifar_ds = cifar_ds.batch(batch_size, drop_remainder=True)

    # apply dataset repeat operation
    cifar_ds = cifar_ds.repeat(repeat_num)

    return cifar_ds
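For reference, a minimal sketch of how this loader can be exercised locally. The CIFAR-10 path and the EasyDict fields are assumptions mirroring the test configs later in this PR, not part of the file above:

# Hypothetical smoke check for create_dataset; values are illustrative only.
from easydict import EasyDict as ed

cfg = ed({"image_height": 224, "image_width": 224})
train_ds = create_dataset("/path/to/cifar-10-batches-bin/",
                          config=cfg, repeat_num=1, batch_size=32)
for images, labels in train_ds.create_tuple_iterator():
    print(images.shape, labels.shape)  # expect (32, 3, 224, 224) and (32,)
    break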
@ -0,0 +1,56 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate generator"""

import math
import numpy as np


def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
    """
    Generate a learning rate array.

    Args:
        global_step (int): current step of training; the schedule is sliced from this step onward.
        lr_init (float): initial learning rate.
        lr_end (float): end learning rate.
        lr_max (float): max learning rate.
        warmup_epochs (int): number of warmup epochs.
        total_epochs (int): total epochs of training.
        steps_per_epoch (int): steps of one epoch.

    Returns:
        np.array, learning rate array.
    """
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs
    for i in range(total_steps):
        if i < warmup_steps:
            lr = lr_init + (lr_max - lr_init) * i / warmup_steps
        else:
            lr = lr_end + \
                (lr_max - lr_end) * \
                (1. + math.cos(math.pi * (i - warmup_steps) /
                               (total_steps - warmup_steps))) / 2.
        if lr < 0.0:
            lr = 0.0
        lr_each_step.append(lr)

    current_step = global_step
    lr_each_step = np.array(lr_each_step).astype(np.float32)
    learning_rate = lr_each_step[current_step:]

    return learning_rate
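A quick, self-contained way to sanity-check the shape of this schedule (the numbers here are illustrative only): warmup rises linearly for one epoch, then the rate cosine-decays toward lr_end.

lrs = get_lr(global_step=0, lr_init=0.0, lr_end=0.0, lr_max=0.3,
             warmup_epochs=1, total_epochs=5, steps_per_epoch=10)
print(lrs[:3])   # ramping up from 0.0 toward lr_max
print(lrs[10])   # first post-warmup step, exactly lr_max (0.3)
print(lrs[-1])   # final step, close to lr_end (0.0)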
@ -0,0 +1,263 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MobileNetV2 Quant model definition."""

import numpy as np

import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore import Tensor

__all__ = ['mobilenetV2']


def _make_divisible(v, divisor, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that rounding down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
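# Worked examples of the rounding rule above (hand-checked, illustrative
# only; not part of the original file). Every channel count is snapped to
# a multiple of `divisor`, and never shrinks by more than 10%:
#   _make_divisible(32, 8) == 32  # already a multiple of 8
#   _make_divisible(20, 8) == 24  # rounds up to the nearest multiple
#   _make_divisible(27, 8) == 32  # 24 would be a >10% drop, so bump up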

class GlobalAvgPooling(nn.Cell):
    """
    Global avg pooling definition.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> GlobalAvgPooling()
    """

    def __init__(self):
        super(GlobalAvgPooling, self).__init__()
        self.mean = P.ReduceMean(keep_dims=False)

    def construct(self, x):
        x = self.mean(x, (2, 3))
        return x


class ConvBNReLU(nn.Cell):
    """
    Convolution/Depthwise fused with BatchNorm and ReLU block definition.

    Args:
        in_planes (int): Input channel.
        out_planes (int): Output channel.
        kernel_size (int): Input kernel size.
        stride (int): Stride size for the first convolutional layer. Default: 1.
        groups (int): Channel group; 1 for convolution, the input channel count for depthwise. Default: 1.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1)
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        super(ConvBNReLU, self).__init__()
        padding = (kernel_size - 1) // 2
        self.conv = nn.Conv2dBnAct(in_planes, out_planes, kernel_size,
                                   stride=stride,
                                   pad_mode='pad',
                                   padding=padding,
                                   group=groups,
                                   has_bn=True,
                                   activation='relu')

    def construct(self, x):
        x = self.conv(x)
        return x


class InvertedResidual(nn.Cell):
    """
    MobileNetV2 residual block definition.

    Args:
        inp (int): Input channel.
        oup (int): Output channel.
        stride (int): Stride size for the first convolutional layer. Default: 1.
        expand_ratio (int): Expand ratio of the input channel.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> InvertedResidual(3, 256, 1, 1)
    """

    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        assert stride in [1, 2]

        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
        layers.extend([
            # dw
            ConvBNReLU(hidden_dim, hidden_dim,
                       stride=stride, groups=hidden_dim),
            # pw-linear
            nn.Conv2dBnAct(hidden_dim, oup, kernel_size=1, stride=1,
                           pad_mode='pad', padding=0, group=1, has_bn=True)
        ])
        self.conv = nn.SequentialCell(layers)
        self.add = P.TensorAdd()

    def construct(self, x):
        out = self.conv(x)
        if self.use_res_connect:
            out = self.add(out, x)
        return out


class mobilenetV2(nn.Cell):
    """
    MobileNetV2 fusion architecture.

    Args:
        num_classes (int): Number of classes. Default is 1000.
        width_mult (float): Channel width multiplier; resulting channels are rounded to multiples of round_nearest. Default is 1.
        has_dropout (bool): Whether dropout is used. Default is False.
        inverted_residual_setting (list): Inverted residual settings. Default is None.
        round_nearest (int): Round channel counts to multiples of this value. Default is 8.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> mobilenetV2(num_classes=1000)
    """

    def __init__(self, num_classes=1000, width_mult=1.,
                 has_dropout=False, inverted_residual_setting=None, round_nearest=8):
        super(mobilenetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        # setting of inverted residual blocks
        self.cfgs = inverted_residual_setting
        if inverted_residual_setting is None:
            self.cfgs = [
                # t, c, n, s
                [1, 16, 1, 1],
                [6, 24, 2, 2],
                [6, 32, 3, 2],
                [6, 64, 4, 2],
                [6, 96, 3, 1],
                [6, 160, 3, 2],
                [6, 320, 1, 1],
            ]

        # building first layer
        input_channel = _make_divisible(
            input_channel * width_mult, round_nearest)
        self.out_channels = _make_divisible(
            last_channel * max(1.0, width_mult), round_nearest)

        features = [ConvBNReLU(3, input_channel, stride=2)]
        # building inverted residual blocks
        for t, c, n, s in self.cfgs:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                stride = s if i == 0 else 1
                features.append(
                    block(input_channel, output_channel, stride, expand_ratio=t))
                input_channel = output_channel
        # building last several layers
        features.append(ConvBNReLU(
            input_channel, self.out_channels, kernel_size=1))
        # wrap the feature layers into a SequentialCell
        self.features = nn.SequentialCell(features)
        # mobilenet head
        head = ([GlobalAvgPooling(),
                 nn.DenseBnAct(self.out_channels, num_classes,
                               has_bias=True, has_bn=False)
                 ] if not has_dropout else
                [GlobalAvgPooling(),
                 nn.Dropout(0.2),
                 nn.DenseBnAct(self.out_channels, num_classes,
                               has_bias=True, has_bn=False)
                 ])
        self.head = nn.SequentialCell(head)

        # init weights
        self.init_parameters_data()
        self._initialize_weights()

    def construct(self, x):
        x = self.features(x)
        x = self.head(x)
        return x

    def _initialize_weights(self):
        """
        Initialize weights.

        Returns:
            None.

        Examples:
            >>> _initialize_weights()
        """
        self.init_parameters_data()
        for _, m in self.cells_and_names():
            np.random.seed(1)
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                w = Tensor(np.random.normal(0, np.sqrt(2. / n),
                                            m.weight.data.shape).astype("float32"))
                m.weight.set_data(w)
                if m.bias is not None:
                    m.bias.set_data(
                        Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
            elif isinstance(m, nn.Conv2dBnAct):
                n = m.conv.kernel_size[0] * \
                    m.conv.kernel_size[1] * m.conv.out_channels
                w = Tensor(np.random.normal(0, np.sqrt(2. / n),
                                            m.conv.weight.data.shape).astype("float32"))
                m.conv.weight.set_data(w)
                if m.conv.bias is not None:
                    m.conv.bias.set_data(
                        Tensor(np.zeros(m.conv.bias.data.shape, dtype="float32")))
            elif isinstance(m, nn.BatchNorm2d):
                m.gamma.set_data(
                    Tensor(np.ones(m.gamma.data.shape, dtype="float32")))
                m.beta.set_data(
                    Tensor(np.zeros(m.beta.data.shape, dtype="float32")))
            elif isinstance(m, nn.Dense):
                m.weight.set_data(Tensor(np.random.normal(
                    0, 0.01, m.weight.data.shape).astype("float32")))
                if m.bias is not None:
                    m.bias.set_data(
                        Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
            elif isinstance(m, nn.DenseBnAct):
                m.dense.weight.set_data(
                    Tensor(np.random.normal(0, 0.01, m.dense.weight.data.shape).astype("float32")))
                if m.dense.bias is not None:
                    m.dense.bias.set_data(
                        Tensor(np.zeros(m.dense.bias.data.shape, dtype="float32")))
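A minimal forward-shape sketch for the fusion network defined above (assumes a working MindSpore install; the PYNATIVE mode used here is an assumption for local runs, not part of the test):

import numpy as np
from mindspore import Tensor, context

context.set_context(mode=context.PYNATIVE_MODE)
net = mobilenetV2(num_classes=10)
dummy = Tensor(np.random.rand(1, 3, 224, 224).astype(np.float32))
print(net(dummy).shape)  # expect (1, 10)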
@ -0,0 +1,123 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Train MobileNetV2_quant on CIFAR-10"""

import pytest
import numpy as np
from easydict import EasyDict as ed

from mindspore import context
from mindspore import Tensor
from mindspore import nn
from mindspore.train.model import Model
from mindspore.train.quant import quant
from mindspore.common import set_seed

from dataset import create_dataset
from lr_generator import get_lr
from utils import Monitor, CrossEntropyWithLabelSmooth
from mobilenetV2 import mobilenetV2

config_ascend_quant = ed({
    "num_classes": 10,
    "image_height": 224,
    "image_width": 224,
    "batch_size": 200,
    "step_threshold": 10,
    "data_load_mode": "mindata",
    "epoch_size": 1,
    "start_epoch": 200,
    "warmup_epochs": 1,
    "lr": 0.3,
    "momentum": 0.9,
    "weight_decay": 4e-5,
    "label_smooth": 0.1,
    "loss_scale": 1024,
    "save_checkpoint": True,
    "save_checkpoint_epochs": 1,
    "keep_checkpoint_max": 300,
    "save_checkpoint_path": "./checkpoint",
    "quantization_aware": True,
})

dataset_path = "/dataset/workspace/mindspore_dataset/cifar-10-batches-bin/"


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def train_on_ascend():
    set_seed(1)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    config = config_ascend_quant
    print("training configuration: {}".format(config))

    epoch_size = config.epoch_size

    # define network
    network = mobilenetV2(num_classes=config.num_classes)
    # define loss
    if config.label_smooth > 0:
        loss = CrossEntropyWithLabelSmooth(
            smooth_factor=config.label_smooth, num_classes=config.num_classes)
    else:
        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    # define dataset
    dataset = create_dataset(dataset_path=dataset_path,
                             config=config,
                             repeat_num=1,
                             batch_size=config.batch_size)
    step_size = dataset.get_dataset_size()

    # convert fusion network to quantization aware network
    network = quant.convert_quant_network(network,
                                          bn_fold=True,
                                          per_channel=[True, False],
                                          symmetric=[True, False])

    # get learning rate
    lr = Tensor(get_lr(global_step=config.start_epoch * step_size,
                       lr_init=0,
                       lr_end=0,
                       lr_max=config.lr,
                       warmup_epochs=config.warmup_epochs,
                       total_epochs=epoch_size + config.start_epoch,
                       steps_per_epoch=step_size))

    # define optimization
    opt = nn.Momentum(filter(lambda x: x.requires_grad, network.get_parameters()), lr, config.momentum,
                      config.weight_decay)
    # define model
    model = Model(network, loss_fn=loss, optimizer=opt)

    print("============== Starting Training ==============")
    monitor = Monitor(lr_init=lr.asnumpy(),
                      step_threshold=config.step_threshold)
    callback = [monitor]
    model.train(epoch_size, dataset, callbacks=callback,
                dataset_sink_mode=False)
    print("============== End Training ==============")

    expect_avg_step_loss = 2.32
    avg_step_loss = np.mean(np.array(monitor.losses))

    print("average step loss:{}".format(avg_step_loss))
    assert avg_step_loss < expect_avg_step_loss


if __name__ == '__main__':
    train_on_ascend()
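Note how start_epoch interacts with get_lr here: the schedule is built for epoch_size + start_epoch epochs and then sliced from global_step onward, so this one-epoch smoke test trains on the tail of a 201-epoch cosine schedule, i.e. with small, already-decayed learning rates. A toy illustration (numbers illustrative only):

lrs = get_lr(global_step=200 * 10, lr_init=0, lr_end=0, lr_max=0.3,
             warmup_epochs=1, total_epochs=201, steps_per_epoch=10)
print(len(lrs))  # 10 steps remain, all near the end of the cosine decay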
@ -0,0 +1,118 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MobileNetV2 utils"""

import time
import numpy as np

from mindspore.train.callback import Callback
from mindspore import Tensor
from mindspore import nn
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype


class Monitor(Callback):
    """
    Monitor loss and time.

    Args:
        lr_init (numpy array): learning rate array used for training.
        step_threshold (int): stop training after this many steps. Default: 10.

    Returns:
        None

    Examples:
        >>> Monitor(lr_init=Tensor([0.05]*100).asnumpy(), step_threshold=10)
    """

    def __init__(self, lr_init=None, step_threshold=10):
        super(Monitor, self).__init__()
        self.lr_init = lr_init
        self.lr_init_len = len(lr_init)
        self.step_threshold = step_threshold

    def epoch_begin(self, run_context):
        self.losses = []
        self.epoch_time = time.time()

    def epoch_end(self, run_context):
        cb_params = run_context.original_args()

        epoch_mseconds = (time.time() - self.epoch_time) * 1000
        per_step_mseconds = epoch_mseconds / cb_params.batch_num
        print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:8.6f}".format(epoch_mseconds,
                                                                                      per_step_mseconds,
                                                                                      np.mean(self.losses)))
        self.epoch_mseconds = epoch_mseconds

    def step_begin(self, run_context):
        self.step_time = time.time()

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        step_mseconds = (time.time() - self.step_time) * 1000
        step_loss = cb_params.net_outputs

        if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):
            step_loss = step_loss[0]
        if isinstance(step_loss, Tensor):
            step_loss = np.mean(step_loss.asnumpy())

        self.losses.append(step_loss)
        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num

        print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:8.6f}/{:5.3f}], time:[{:5.3f}], lr:[{:5.5f}]".format(
            cb_params.cur_epoch_num, cb_params.epoch_num, cur_step_in_epoch +
            1, cb_params.batch_num, step_loss,
            np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1]))

        if cb_params.cur_step_num == self.step_threshold:
            run_context.request_stop()


class CrossEntropyWithLabelSmooth(_Loss):
    """
    CrossEntropyWithLabelSmooth.

    Args:
        smooth_factor (float): smoothing factor. Default: 0.
        num_classes (int): number of classes. Default: 1000.

    Returns:
        None.

    Examples:
        >>> CrossEntropyWithLabelSmooth(smooth_factor=0., num_classes=1000)
    """

    def __init__(self, smooth_factor=0., num_classes=1000):
        super(CrossEntropyWithLabelSmooth, self).__init__()
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(1.0 * smooth_factor /
                                (num_classes - 1), mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean(False)
        self.cast = P.Cast()

    def construct(self, logit, label):
        one_hot_label = self.onehot(self.cast(label, mstype.int32), F.shape(logit)[1],
                                    self.on_value, self.off_value)
        out_loss = self.ce(logit, one_hot_label)
        out_loss = self.mean(out_loss, 0)
        return out_loss
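For the configuration used in the MobileNetV2 test (smooth_factor=0.1, num_classes=10), the smoothed targets work out as follows; a pure-Python check, independent of MindSpore:

# on_value = 1 - 0.1 = 0.9 for the true class;
# off_value = 0.1 / (10 - 1) ~ 0.0111 for each of the other 9 classes;
# each target row still sums to 1.0.
on_value = 1.0 - 0.1
off_value = 0.1 / (10 - 1)
print(on_value, off_value, on_value + 9 * off_value)  # 0.9 0.0111... 1.0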
@ -0,0 +1,68 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Create the train dataset."""

from functools import partial

import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.c_transforms as C2
import mindspore.dataset.vision.c_transforms as C


def create_dataset(dataset_path, config, repeat_num=1, batch_size=32):
    """
    Create a train dataset.

    Args:
        dataset_path (string): the path of the dataset.
        config (EasyDict): the basic config for training.
        repeat_num (int): the repeat times of the dataset. Default: 1.
        batch_size (int): the batch size of the dataset. Default: 32.

    Returns:
        Dataset, the processed CIFAR-10 train dataset.
    """
    load_func = partial(de.Cifar10Dataset, dataset_path)
    ds = load_func(num_parallel_workers=8, shuffle=False)

    resize_height = config.image_height
    resize_width = config.image_width

    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

    # define map operations
    resize_op = C.Resize((resize_height, resize_width))
    normalize_op = C.Normalize(mean=mean, std=std)
    changeswap_op = C.HWC2CHW()
    c_trans = [resize_op, normalize_op, changeswap_op]

    type_cast_op = C2.TypeCast(mstype.int32)

    ds = ds.map(operations=c_trans, input_columns="image",
                num_parallel_workers=8)
    ds = ds.map(operations=type_cast_op,
                input_columns="label", num_parallel_workers=8)

    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)

    # apply dataset repeat operation
    ds = ds.repeat(repeat_num)

    return ds
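Unlike the MobileNetV2 pipeline earlier in this PR, this loader has no Rescale op; it folds the 1/255 factor into Normalize by scaling the ImageNet mean/std to the 0-255 range. The two formulations are algebraically equivalent, as this standalone check shows:

# (x - mean*255) / (std*255) == (x/255 - mean) / std, element-wise.
x = 200.0
mean, std = 0.485, 0.229
assert abs((x - mean * 255) / (std * 255) - ((x / 255 - mean) / std)) < 1e-9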
@ -0,0 +1,93 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate generator"""

import math
import numpy as np


def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode):
    """
    Generate a learning rate array.

    Args:
        lr_init (float): initial learning rate.
        lr_end (float): end learning rate.
        lr_max (float): max learning rate.
        warmup_epochs (int): number of warmup epochs.
        total_epochs (int): total epochs of training.
        steps_per_epoch (int): steps of one epoch.
        lr_decay_mode (string): learning rate decay mode: 'steps', 'poly', 'cosine' or default.

    Returns:
        np.array, learning rate array.
    """
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs
    if lr_decay_mode == 'steps':
        decay_epoch_index = [0.3 * total_steps,
                             0.6 * total_steps, 0.8 * total_steps]
        for i in range(total_steps):
            if i < decay_epoch_index[0]:
                lr = lr_max
            elif i < decay_epoch_index[1]:
                lr = lr_max * 0.1
            elif i < decay_epoch_index[2]:
                lr = lr_max * 0.01
            else:
                lr = lr_max * 0.001
            lr_each_step.append(lr)
    elif lr_decay_mode == 'poly':
        if warmup_steps != 0:
            inc_each_step = (float(lr_max) - float(lr_init)) / \
                float(warmup_steps)
        else:
            inc_each_step = 0
        for i in range(total_steps):
            if i < warmup_steps:
                lr = float(lr_init) + inc_each_step * float(i)
            else:
                base = (1.0 - (float(i) - float(warmup_steps)) /
                        (float(total_steps) - float(warmup_steps)))
                lr = float(lr_max) * base * base
            if lr < 0.0:
                lr = 0.0
            lr_each_step.append(lr)
    elif lr_decay_mode == 'cosine':
        decay_steps = total_steps - warmup_steps
        for i in range(total_steps):
            if i < warmup_steps:
                lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps)
                lr = float(lr_init) + lr_inc * (i + 1)
            else:
                linear_decay = (total_steps - i) / decay_steps
                cosine_decay = 0.5 * \
                    (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
                decayed = linear_decay * cosine_decay + 0.00001
                lr = lr_max * decayed
            lr_each_step.append(lr)
    else:
        for i in range(total_steps):
            if i < warmup_steps:
                lr = lr_init + (lr_max - lr_init) * i / warmup_steps
            else:
                lr = lr_max - (lr_max - lr_end) * \
                    (i - warmup_steps) / (total_steps - warmup_steps)
            lr_each_step.append(lr)

    learning_rate = np.array(lr_each_step).astype(np.float32)

    return learning_rate
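A compact way to eyeball the four decay modes side by side (a sketch with illustrative numbers; any unrecognized mode string falls through to the linear default):

for mode in ('steps', 'poly', 'cosine', 'default'):
    lrs = get_lr(lr_init=0.0, lr_end=0.0, lr_max=0.1, warmup_epochs=0,
                 total_epochs=4, steps_per_epoch=5, lr_decay_mode=mode)
    print(mode, lrs[0], lrs[-1])  # first and last step of each schedule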
@ -0,0 +1,354 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ResNet."""

import numpy as np

import mindspore.nn as nn
import mindspore.common.initializer as weight_init
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.nn import FakeQuantWithMinMax, Conv2dBnFoldQuant as Conv2dBatchNormQuant


_ema_decay = 0.999
_symmetric = True
_fake = True
_per_channel = True


def _weight_variable(shape, factor=0.01):
    init_value = np.random.randn(*shape).astype(np.float32) * factor
    return Tensor(init_value)


def _conv3x3(in_channel, out_channel, stride=1):
    weight_shape = (out_channel, in_channel, 3, 3)
    weight = _weight_variable(weight_shape)
    return nn.Conv2d(in_channel, out_channel,
                     kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight)


def _conv1x1(in_channel, out_channel, stride=1):
    weight_shape = (out_channel, in_channel, 1, 1)
    weight = _weight_variable(weight_shape)
    return nn.Conv2d(in_channel, out_channel,
                     kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight)


def _conv7x7(in_channel, out_channel, stride=1):
    weight_shape = (out_channel, in_channel, 7, 7)
    weight = _weight_variable(weight_shape)
    return nn.Conv2d(in_channel, out_channel,
                     kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight)


def _bn(channel):
    return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
                          gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)


def _bn_last(channel):
    return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
                          gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1)


def _fc(in_channel, out_channel):
    weight_shape = (out_channel, in_channel)
    weight = _weight_variable(weight_shape)
    return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0)


class ConvBNReLU(nn.Cell):
    """
    Convolution/Depthwise fused with BatchNorm and ReLU block definition.

    Args:
        in_planes (int): Input channel.
        out_planes (int): Output channel.
        kernel_size (int): Input kernel size.
        stride (int): Stride size for the first convolutional layer. Default: 1.
        groups (int): Channel group; 1 for convolution, the input channel count for depthwise. Default: 1.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1)
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        super(ConvBNReLU, self).__init__()
        padding = (kernel_size - 1) // 2
        conv = Conv2dBatchNormQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding,
                                    group=groups, fake=_fake, per_channel=_per_channel, symmetric=_symmetric)
        layers = [conv, nn.ActQuant(nn.ReLU())] if _fake else [conv, nn.ReLU()]
        self.features = nn.SequentialCell(layers)

    def construct(self, x):
        output = self.features(x)
        return output


class ResidualBlock(nn.Cell):
    """
    ResNet V1 residual block definition.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        stride (int): Stride size for the first convolutional layer. Default: 1.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResidualBlock(3, 256, stride=2)
    """
    expansion = 4

    def __init__(self,
                 in_channel,
                 out_channel,
                 stride=1):
        super(ResidualBlock, self).__init__()

        channel = out_channel // self.expansion
        self.conv1 = ConvBNReLU(in_channel, channel, kernel_size=1, stride=1)
        self.conv2 = ConvBNReLU(channel, channel, kernel_size=3, stride=stride)
        self.conv3 = nn.SequentialCell([Conv2dBatchNormQuant(channel, out_channel, fake=_fake,
                                                             per_channel=_per_channel,
                                                             symmetric=_symmetric,
                                                             kernel_size=1, stride=1, pad_mode='same', padding=0),
                                        FakeQuantWithMinMax(
                                            ema=True, ema_decay=_ema_decay, symmetric=False)
                                        ]) if _fake else Conv2dBatchNormQuant(channel, out_channel, fake=_fake,
                                                                              per_channel=_per_channel,
                                                                              symmetric=_symmetric,
                                                                              kernel_size=1, stride=1,
                                                                              pad_mode='same', padding=0)

        self.down_sample = False

        if stride != 1 or in_channel != out_channel:
            self.down_sample = True
        self.down_sample_layer = None

        if self.down_sample:
            self.down_sample_layer = nn.SequentialCell([Conv2dBatchNormQuant(in_channel, out_channel,
                                                                             per_channel=_per_channel,
                                                                             symmetric=_symmetric,
                                                                             kernel_size=1, stride=stride,
                                                                             pad_mode='same', padding=0),
                                                        FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay,
                                                                            symmetric=False)
                                                        ]) if _fake else Conv2dBatchNormQuant(in_channel, out_channel,
                                                                                              fake=_fake,
                                                                                              per_channel=_per_channel,
                                                                                              symmetric=_symmetric,
                                                                                              kernel_size=1,
                                                                                              stride=stride,
                                                                                              pad_mode='same',
                                                                                              padding=0)
        self.add = nn.TensorAddQuant()
        self.relu = P.ReLU()

    def construct(self, x):
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)

        if self.down_sample:
            identity = self.down_sample_layer(identity)

        out = self.add(out, identity)
        out = self.relu(out)

        return out


class ResNet(nn.Cell):
    """
    ResNet architecture.

    Args:
        block (Cell): Block for network.
        layer_nums (list): Numbers of blocks in the different layers.
        in_channels (list): Input channel in each layer.
        out_channels (list): Output channel in each layer.
        strides (list): Stride size in each layer.
        num_classes (int): The number of classes the training images belong to.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResNet(ResidualBlock,
        >>>        [3, 4, 6, 3],
        >>>        [64, 256, 512, 1024],
        >>>        [256, 512, 1024, 2048],
        >>>        [1, 2, 2, 2],
        >>>        10)
    """

    def __init__(self,
                 block,
                 layer_nums,
                 in_channels,
                 out_channels,
                 strides,
                 num_classes):
        super(ResNet, self).__init__()

        if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
            raise ValueError(
                "the length of layer_num, in_channels, out_channels list must be 4!")

        self.conv1 = ConvBNReLU(3, 64, kernel_size=7, stride=2)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")

        self.layer1 = self._make_layer(block,
                                       layer_nums[0],
                                       in_channel=in_channels[0],
                                       out_channel=out_channels[0],
                                       stride=strides[0])
        self.layer2 = self._make_layer(block,
                                       layer_nums[1],
                                       in_channel=in_channels[1],
                                       out_channel=out_channels[1],
                                       stride=strides[1])
        self.layer3 = self._make_layer(block,
                                       layer_nums[2],
                                       in_channel=in_channels[2],
                                       out_channel=out_channels[2],
                                       stride=strides[2])
        self.layer4 = self._make_layer(block,
                                       layer_nums[3],
                                       in_channel=in_channels[3],
                                       out_channel=out_channels[3],
                                       stride=strides[3])

        self.mean = P.ReduceMean(keep_dims=True)
        self.flatten = nn.Flatten()
        self.end_point = nn.DenseQuant(out_channels[3], num_classes, has_bias=True, per_channel=_per_channel,
                                       symmetric=_symmetric)
        self.output_fake = nn.FakeQuantWithMinMax(
            ema=True, ema_decay=_ema_decay)

        # init weights
        self._initialize_weights()

    def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
        """
        Make stage network of ResNet.

        Args:
            block (Cell): Resnet block.
            layer_num (int): Layer number.
            in_channel (int): Input channel.
            out_channel (int): Output channel.
            stride (int): Stride size for the first convolutional layer.

        Returns:
            SequentialCell, the output layer.

        Examples:
            >>> _make_layer(ResidualBlock, 3, 128, 256, 2)
        """
        layers = []

        resnet_block = block(in_channel, out_channel, stride=stride)
        layers.append(resnet_block)

        for _ in range(1, layer_num):
            resnet_block = block(out_channel, out_channel, stride=1)
            layers.append(resnet_block)

        return nn.SequentialCell(layers)

    def construct(self, x):
        x = self.conv1(x)
        c1 = self.maxpool(x)

        c2 = self.layer1(c1)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)

        out = self.mean(c5, (2, 3))
        out = self.flatten(out)
        out = self.end_point(out)
        out = self.output_fake(out)
        return out

    def _initialize_weights(self):
        self.init_parameters_data()
        for _, m in self.cells_and_names():
            np.random.seed(1)

            if isinstance(m, nn.Conv2dBnFoldQuant):
                m.weight.set_data(weight_init.initializer(weight_init.Normal(),
                                                          m.weight.shape,
                                                          m.weight.dtype))
            elif isinstance(m, nn.DenseQuant):
                m.weight.set_data(weight_init.initializer(weight_init.Normal(),
                                                          m.weight.shape,
                                                          m.weight.dtype))
            elif isinstance(m, nn.Conv2dBnWithoutFoldQuant):
                m.weight.set_data(weight_init.initializer(weight_init.Normal(),
                                                          m.weight.shape,
                                                          m.weight.dtype))


def resnet50_quant(class_num=10):
    """
    Get ResNet50 neural network.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of ResNet50 neural network.

    Examples:
        >>> net = resnet50_quant(10)
    """
    return ResNet(ResidualBlock,
                  [3, 4, 6, 3],
                  [64, 256, 512, 1024],
                  [256, 512, 1024, 2048],
                  [1, 2, 2, 2],
                  class_num)


def resnet101_quant(class_num=1001):
    """
    Get ResNet101 neural network.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of ResNet101 neural network.

    Examples:
        >>> net = resnet101_quant(1001)
    """
    return ResNet(ResidualBlock,
                  [3, 4, 23, 3],
                  [64, 256, 512, 1024],
                  [256, 512, 1024, 2048],
                  [1, 2, 2, 2],
                  class_num)
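A rough local sanity check for the constructor above (a sketch, assuming a working MindSpore install; the ~25M figure is the usual ResNet-50 ballpark, not a value asserted by this PR):

import numpy as np

net = resnet50_quant(class_num=10)
n_params = sum(np.prod(p.data.shape) for p in net.get_parameters())
print(n_params)  # on the order of 25M weights for a ResNet-50 backbone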
@ -0,0 +1,131 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Train ResNet50_quant on CIFAR-10"""

import pytest
import numpy as np
from easydict import EasyDict as ed

from mindspore import context
from mindspore import Tensor
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model
from mindspore.train.quant import quant
from mindspore import set_seed

from resnet_quant_manual import resnet50_quant
from dataset import create_dataset
from lr_generator import get_lr
from utils import Monitor, CrossEntropy


config_quant = ed({
    "class_num": 10,
    "batch_size": 128,
    "step_threshold": 20,
    "loss_scale": 1024,
    "momentum": 0.9,
    "weight_decay": 1e-4,
    "epoch_size": 1,
    "pretrained_epoch_size": 90,
    "buffer_size": 1000,
    "image_height": 224,
    "image_width": 224,
    "data_load_mode": "mindata",
    "save_checkpoint": True,
    "save_checkpoint_epochs": 1,
    "keep_checkpoint_max": 50,
    "save_checkpoint_path": "./",
    "warmup_epochs": 0,
    "lr_decay_mode": "cosine",
    "use_label_smooth": True,
    "label_smooth_factor": 0.1,
    "lr_init": 0,
    "lr_max": 0.005,
})

dataset_path = "/dataset/workspace/mindspore_dataset/cifar-10-batches-bin/"


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def train_on_ascend():
    set_seed(1)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    config = config_quant
    print("training configuration: {}".format(config))
    epoch_size = config.epoch_size

    # define network
    net = resnet50_quant(class_num=config.class_num)
    net.set_train(True)

    # define loss
    if not config.use_label_smooth:
        config.label_smooth_factor = 0.0
    loss = CrossEntropy(
        smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
    # loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)

    # define dataset
    dataset = create_dataset(dataset_path=dataset_path,
                             config=config,
                             repeat_num=1,
                             batch_size=config.batch_size)
    step_size = dataset.get_dataset_size()

    # convert fusion network to quantization aware network
    net = quant.convert_quant_network(net,
                                      bn_fold=True,
                                      per_channel=[True, False],
                                      symmetric=[True, False])

    # get learning rate
    lr = Tensor(get_lr(lr_init=config.lr_init,
                       lr_end=0.0,
                       lr_max=config.lr_max,
                       warmup_epochs=config.warmup_epochs,
                       total_epochs=config.epoch_size,
                       steps_per_epoch=step_size,
                       lr_decay_mode='cosine'))

    # define optimization
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
                   config.weight_decay, config.loss_scale)

    # define model
    # model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'})
    model = Model(net, loss_fn=loss, optimizer=opt)

    print("============== Starting Training ==============")
    monitor = Monitor(lr_init=lr.asnumpy(),
                      step_threshold=config.step_threshold)

    callbacks = [monitor]
    model.train(epoch_size, dataset, callbacks=callbacks,
                dataset_sink_mode=False)
    print("============== End Training ==============")

    expect_avg_step_loss = 2.40
    avg_step_loss = np.mean(np.array(monitor.losses))

    print("average step loss:{}".format(avg_step_loss))
    assert avg_step_loss < expect_avg_step_loss


if __name__ == '__main__':
    train_on_ascend()
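The loss thresholds in these two tests (2.32 for MobileNetV2, 2.40 for ResNet-50) sit just above the expected loss of an untrained 10-class classifier, so the assertions mostly check that training ran and the loss did not blow up. The reference value:

import math
# Uniform predictions over 10 classes give cross-entropy ln(10).
print(math.log(10))  # ~2.3026, just under the 2.32 / 2.40 thresholds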
@ -0,0 +1,105 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ResNet50 utils"""

import time
import numpy as np

from mindspore.train.callback import Callback
from mindspore import Tensor
from mindspore import nn
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype


class Monitor(Callback):
    """
    Monitor loss and time.

    Args:
        lr_init (numpy array): learning rate array used for training.
        step_threshold (int): stop training after this many steps. Default: 10.

    Returns:
        None

    Examples:
        >>> Monitor(lr_init=Tensor([0.05]*100).asnumpy(), step_threshold=10)
    """

    def __init__(self, lr_init=None, step_threshold=10):
        super(Monitor, self).__init__()
        self.lr_init = lr_init
        self.lr_init_len = len(lr_init)
        self.step_threshold = step_threshold

    def epoch_begin(self, run_context):
        self.losses = []
        self.epoch_time = time.time()

    def epoch_end(self, run_context):
        cb_params = run_context.original_args()

        epoch_mseconds = (time.time() - self.epoch_time) * 1000
        per_step_mseconds = epoch_mseconds / cb_params.batch_num
        print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:8.6f}".format(epoch_mseconds,
                                                                                      per_step_mseconds,
                                                                                      np.mean(self.losses)))
        self.epoch_mseconds = epoch_mseconds

    def step_begin(self, run_context):
        self.step_time = time.time()

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        step_mseconds = (time.time() - self.step_time) * 1000
        step_loss = cb_params.net_outputs

        if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):
            step_loss = step_loss[0]
        if isinstance(step_loss, Tensor):
            step_loss = np.mean(step_loss.asnumpy())

        self.losses.append(step_loss)
        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num

        print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:8.6f}/{:8.6f}], time:[{:5.3f}], lr:[{:5.5f}]".format(
            cb_params.cur_epoch_num, cb_params.epoch_num, cur_step_in_epoch +
            1, cb_params.batch_num, step_loss,
            np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1]))

        if cb_params.cur_step_num == self.step_threshold:
            run_context.request_stop()


class CrossEntropy(_Loss):
    """The redefined loss function with SoftmaxCrossEntropyWithLogits."""

    def __init__(self, smooth_factor=0, num_classes=1001):
        super(CrossEntropy, self).__init__()
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(1.0 * smooth_factor /
                                (num_classes - 1), mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean(False)

    def construct(self, logit, label):
        one_hot_label = self.onehot(label, F.shape(
            logit)[1], self.on_value, self.off_value)
        loss = self.ce(logit, one_hot_label)
        loss = self.mean(loss, 0)
        return loss