forked from mindspore-Ecosystem/mindspore
!2943 [quant]export bug fix
Merge pull request !2943 from vlne-v1/quant_export_bugfix
commit d3ec05d716
@@ -328,6 +328,9 @@ std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchI
      x = cnode->input(1);
      count += 1;
    }
    if (x->isa<Parameter>()) {
      fake_quant_table[weight_name] = std::make_pair(nullptr, "input");
    }
    // get the fakequant parameter minq's name
    if (!is_quant_cnode(x)) {
      continue;
@@ -1169,9 +1169,9 @@ class QuantBlock(Cell):
        return x

    def extend_repr(self):
        str_info = f'quant={self.quant}, core_op={type(self.core_op)}'
        str_info = f'quant={self.quant}, core_op={type(self.core_op)}, weight=shape[{self.weight.shape}]'
        if self.has_bias:
            str_info = str_info + f', bias={self.bias}'
            str_info = str_info + f', bias=shape[{self.bias.shape}]'
        if self.has_act:
            str_info = str_info + f', activation={self.activation}'
        str_info = str_info + f', dequant={self.dequant}'
@@ -237,12 +237,14 @@ class PrimitiveWithInfer(Primitive):
    """
    Infer output shape based on input shape.

    Args:
        inputs (tuple(int)): dimensions of input tensors.
        outputs (tuple(int)): dimensions of output tensors.

    Note:
        The shape of scalar is an empty tuple.

    Args:
        args (tuple(int)): shapes of input tensors.

    Return:
        `tuple(int)`, shapes of output tensors.
    """
    return None

@@ -251,8 +253,10 @@ class PrimitiveWithInfer(Primitive):
    Infer output dtype based on input dtype.

    Args:
        inputs (mstype): data type of inputs.
        outputs (mstype): data type of outputs.
        args (:class:`mindspore.dtype`): data type of inputs.

    Return:
        :class:`mindspore.dtype`, data type of outputs.
    """
    return None

@@ -261,8 +265,10 @@ class PrimitiveWithInfer(Primitive):
    Infer output value based on input value at compile time.

    Args:
        inputs (any): value of inputs.
        outputs (any): value of outputs.
        args (Any): value of inputs.

    Return:
        Value of outputs. Return `None` if the value cannot be inferred at compile time.
    """
    return None

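The three infer_* hooks documented above are what a custom operator overrides. Below is a minimal sketch of a PrimitiveWithInfer subclass, assuming the usual prim_attr_register pattern; the operator name and its pass-through behaviour are illustrative only and are not part of this change.

from mindspore.ops.primitive import PrimitiveWithInfer, prim_attr_register

class IdentityDemo(PrimitiveWithInfer):
    """Hypothetical pass-through operator used only to illustrate the infer hooks."""

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['x'], outputs=['y'])

    def infer_shape(self, x_shape):
        # output shape follows the input; a scalar would be the empty tuple ()
        return x_shape

    def infer_dtype(self, x_dtype):
        # output dtype follows the input dtype
        return x_dtype

    def infer_value(self, x_value):
        # only known at compile time when the input is a constant, otherwise None
        return x_value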
@@ -318,9 +318,12 @@ class ExportToQuantInferNetwork:
        info = self.quant_info_table.get(w_minq_name, None)
        if info:
            fack_quant_a_in_op, minq_name = info
            maxq = self.all_parameters[minq_name[:-4] + "maxq"]
            minq = self.all_parameters[minq_name]
            scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, maxq, minq, np_type)
            if minq_name == 'input':
                scale_a_in, zp_a_in = self.input_scale, self.input_zero_point
            else:
                maxq = self.all_parameters[minq_name[:-4] + "maxq"]
                minq = self.all_parameters[minq_name]
                scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, maxq, minq, np_type)
        else:
            logger.warning(f"Do not find `fake_quant` from input with `fake_quant.minq` {w_minq_name}")
            return None
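For context, scale_zp_from_data turns the recorded minq/maxq observations into an affine scale and zero point. The numpy sketch below shows the standard asymmetric-quantization formula only; it is an illustration of the math, not necessarily the exact quant_utils implementation, and the function name is made up for the example.

import numpy as np

def scale_zp_sketch(minq, maxq, num_bits=8):
    # affine quantization convention: real_value = scale * (quant_value - zero_point)
    qmin, qmax = 0, 2 ** num_bits - 1
    minq = np.minimum(minq, 0.0)   # keep 0.0 exactly representable
    maxq = np.maximum(maxq, 0.0)
    scale = (maxq - minq) / (qmax - qmin)
    scale = np.where(scale == 0, 1e-8, scale)   # avoid division by zero for constant tensors
    zero_point = np.round(qmin - minq / scale)
    return scale, zero_point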
@@ -104,19 +104,20 @@ def weight2int(data, scale, zero_point):
        raise ValueError("`scale` and `zero_point` should have the same shape.")
    if scale.shape[0] < 0:
        raise ValueError("`scale` and `zero_point` shape should greater than zero.")

    if scale.shape[0] == data.shape[0]:
        # `Conv2d` or `Dense` op weight
        shape_list = [-1] + [1] * len(data.shape[1:])
        scale = scale.reshape(shape_list)
        zero_point = zero_point.reshape(shape_list)
    elif scale.shape[0] == data.shape[1]:
        # `DepthwiseConv2d` op weight
        shape_list = [1, -1] + [1] * len(data.shape[2:])
        scale = scale.reshape(shape_list)
        zero_point = zero_point.reshape(shape_list)
    else:
        raise ValueError("Unsupported weight shape({})".format(data.shape))
    if len(scale.shape) > 1:
        # for perchannel
        if scale.shape[0] == data.shape[0]:
            # `Conv2d` or `Dense` op weight
            shape_list = [-1] + [1] * len(data.shape[1:])
            scale = scale.reshape(shape_list)
            zero_point = zero_point.reshape(shape_list)
        elif scale.shape[0] == data.shape[1]:
            # `DepthwiseConv2d` op weight
            shape_list = [1, -1] + [1] * len(data.shape[2:])
            scale = scale.reshape(shape_list)
            zero_point = zero_point.reshape(shape_list)
        else:
            raise ValueError("Unsupported weight shape({})".format(data.shape))

    return np.round((data / scale) + zero_point)

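The reshaping above only applies when scale/zero_point are per-channel vectors; per-layer scalars already broadcast. A small standalone numpy illustration of the per-channel case for a Conv2d-style weight, with shapes chosen arbitrarily for the example:

import numpy as np

weight = np.random.randn(8, 3, 3, 3).astype(np.float32)   # [out_channels, in_channels, kh, kw]
scale = np.linspace(0.01, 0.02, 8).astype(np.float32)     # one scale per output channel
zero_point = np.zeros(8, dtype=np.float32)

# reshape to [-1, 1, 1, 1] so the per-channel values broadcast over the weight
shape_list = [-1] + [1] * len(weight.shape[1:])
quantized = np.round(weight / scale.reshape(shape_list) + zero_point.reshape(shape_list))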
@ -1,115 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""MobileNetV2"""
|
||||
from mindspore import nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
def make_divisible(input_x, div_by=8):
|
||||
return int((input_x + div_by) // div_by)
|
||||
|
||||
|
||||
def _conv_bn(in_channel,
|
||||
out_channel,
|
||||
ksize,
|
||||
stride=1):
|
||||
"""Get a conv2d batchnorm and relu layer."""
|
||||
return nn.SequentialCell(
|
||||
[nn.Conv2d(in_channel,
|
||||
out_channel,
|
||||
kernel_size=ksize,
|
||||
stride=stride),
|
||||
nn.BatchNorm2d(out_channel)])
|
||||
|
||||
|
||||
class InvertedResidual(nn.Cell):
|
||||
def __init__(self, inp, oup, stride, expend_ratio):
|
||||
super(InvertedResidual, self).__init__()
|
||||
self.stride = stride
|
||||
assert stride in [1, 2]
|
||||
|
||||
hidden_dim = int(inp * expend_ratio)
|
||||
self.use_res_connect = self.stride == 1 and inp == oup
|
||||
if expend_ratio == 1:
|
||||
self.conv = nn.SequentialCell([
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, group=hidden_dim),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.ReLU6(),
|
||||
nn.Conv2d(hidden_dim, oup, 1, 1),
|
||||
nn.BatchNorm2d(oup)
|
||||
])
|
||||
else:
|
||||
self.conv = nn.SequentialCell([
|
||||
nn.Conv2d(inp, hidden_dim, 1, 1),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.ReLU6(),
|
||||
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, group=hidden_dim),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.ReLU6(),
|
||||
|
||||
nn.Conv2d(hidden_dim, oup, 1, 1),
|
||||
nn.BatchNorm2d(oup)
|
||||
])
|
||||
|
||||
def construct(self, input_x):
|
||||
out = self.conv(input_x)
|
||||
if self.use_res_connect:
|
||||
out = input_x + out
|
||||
return out
|
||||
|
||||
|
||||
class MobileNetV2(nn.Cell):
|
||||
def __init__(self, num_class=1000, input_size=224, width_mul=1.):
|
||||
super(MobileNetV2, self).__init__()
|
||||
_ = input_size
|
||||
block = InvertedResidual
|
||||
input_channel = 32
|
||||
last_channel = 1280
|
||||
inverted_residual_setting = [
|
||||
[1, 16, 1, 1],
|
||||
[6, 24, 2, 2],
|
||||
[6, 32, 3, 2],
|
||||
[6, 64, 4, 2],
|
||||
[6, 96, 3, 1],
|
||||
[6, 160, 3, 2],
|
||||
[6, 230, 1, 1],
|
||||
]
|
||||
if width_mul > 1.0:
|
||||
last_channel = make_divisible(last_channel * width_mul)
|
||||
self.last_channel = last_channel
|
||||
features = [_conv_bn(3, input_channel, 3, 2)]
|
||||
|
||||
for t, c, n, s in inverted_residual_setting:
|
||||
out_channel = make_divisible(c * width_mul) if t > 1 else c
|
||||
for i in range(n):
|
||||
if i == 0:
|
||||
features.append(block(input_channel, out_channel, s, t))
|
||||
else:
|
||||
features.append(block(input_channel, out_channel, 1, t))
|
||||
input_channel = out_channel
|
||||
|
||||
features.append(_conv_bn(input_channel, self.last_channel, 1))
|
||||
|
||||
self.features = nn.SequentialCell(features)
|
||||
self.mean = P.ReduceMean(keep_dims=False)
|
||||
self.classifier = nn.Dense(self.last_channel, num_class)
|
||||
|
||||
def construct(self, input_x):
|
||||
out = input_x
|
||||
out = self.features(out)
|
||||
out = self.mean(out, (2, 3))
|
||||
out = self.classifier(out)
|
||||
return out
|
|
@ -1,122 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""mobile net v2"""
|
||||
from mindspore import nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
def make_divisible(input_x, div_by=8):
|
||||
return int((input_x + div_by) // div_by)
|
||||
|
||||
|
||||
def _conv_bn(in_channel,
|
||||
out_channel,
|
||||
ksize,
|
||||
stride=1):
|
||||
"""Get a conv2d batchnorm and relu layer."""
|
||||
return nn.SequentialCell(
|
||||
[nn.Conv2dBnAct(in_channel,
|
||||
out_channel,
|
||||
kernel_size=ksize,
|
||||
stride=stride,
|
||||
has_bn=True)])
|
||||
|
||||
|
||||
class InvertedResidual(nn.Cell):
|
||||
def __init__(self, inp, oup, stride, expend_ratio):
|
||||
super(InvertedResidual, self).__init__()
|
||||
self.stride = stride
|
||||
assert stride in [1, 2]
|
||||
|
||||
hidden_dim = int(inp * expend_ratio)
|
||||
self.use_res_connect = self.stride == 1 and inp == oup
|
||||
if expend_ratio == 1:
|
||||
self.conv = nn.SequentialCell([
|
||||
nn.Conv2dBnAct(hidden_dim,
|
||||
hidden_dim,
|
||||
3,
|
||||
stride,
|
||||
group=hidden_dim,
|
||||
has_bn=True,
|
||||
activation='relu6'),
|
||||
nn.Conv2dBnAct(hidden_dim, oup, 1, 1,
|
||||
has_bn=True)
|
||||
])
|
||||
else:
|
||||
self.conv = nn.SequentialCell([
|
||||
nn.Conv2dBnAct(inp, hidden_dim, 1, 1,
|
||||
has_bn=True,
|
||||
activation='relu6'),
|
||||
nn.Conv2dBnAct(hidden_dim,
|
||||
hidden_dim,
|
||||
3,
|
||||
stride,
|
||||
group=hidden_dim,
|
||||
has_bn=True,
|
||||
activation='relu6'),
|
||||
nn.Conv2dBnAct(hidden_dim, oup, 1, 1,
|
||||
has_bn=True)
|
||||
])
|
||||
self.add = P.TensorAdd()
|
||||
|
||||
def construct(self, input_x):
|
||||
out = self.conv(input_x)
|
||||
if self.use_res_connect:
|
||||
out = self.add(input_x, out)
|
||||
return out
|
||||
|
||||
|
||||
class MobileNetV2(nn.Cell):
|
||||
def __init__(self, num_class=1000, input_size=224, width_mul=1.):
|
||||
super(MobileNetV2, self).__init__()
|
||||
_ = input_size
|
||||
block = InvertedResidual
|
||||
input_channel = 32
|
||||
last_channel = 1280
|
||||
inverted_residual_setting = [
|
||||
[1, 16, 1, 1],
|
||||
[6, 24, 2, 2],
|
||||
[6, 32, 3, 2],
|
||||
[6, 64, 4, 2],
|
||||
[6, 96, 3, 1],
|
||||
[6, 160, 3, 2],
|
||||
[6, 230, 1, 1],
|
||||
]
|
||||
if width_mul > 1.0:
|
||||
last_channel = make_divisible(last_channel * width_mul)
|
||||
self.last_channel = last_channel
|
||||
features = [_conv_bn(3, input_channel, 3, 2)]
|
||||
|
||||
for t, c, n, s in inverted_residual_setting:
|
||||
out_channel = make_divisible(c * width_mul) if t > 1 else c
|
||||
for i in range(n):
|
||||
if i == 0:
|
||||
features.append(block(input_channel, out_channel, s, t))
|
||||
else:
|
||||
features.append(block(input_channel, out_channel, 1, t))
|
||||
input_channel = out_channel
|
||||
|
||||
features.append(_conv_bn(input_channel, self.last_channel, 1))
|
||||
|
||||
self.features = nn.SequentialCell(features)
|
||||
self.mean = P.ReduceMean(keep_dims=False)
|
||||
self.classifier = nn.DenseBnAct(self.last_channel, num_class)
|
||||
|
||||
def construct(self, input_x):
|
||||
out = input_x
|
||||
out = self.features(out)
|
||||
out = self.mean(out, (2, 3))
|
||||
out = self.classifier(out)
|
||||
return out
|
|
@@ -20,7 +20,7 @@ import mindspore.context as context
from mindspore import Tensor
from mindspore import nn
from mindspore.train.quant import quant as qat
from mobilenetv2_combined import MobileNetV2
from model_zoo.mobilenetv2_quant.src.mobilenetV2 import mobilenetV2

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

@@ -42,7 +42,7 @@ class LeNet5(nn.Cell):
    def __init__(self, num_class=10):
        super(LeNet5, self).__init__()
        self.num_class = num_class
        self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, has_bn=True, activation='relu6', pad_mode="valid")
        self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, has_bn=True, activation='relu', pad_mode="valid")
        self.conv2 = nn.Conv2dBnAct(6, 16, kernel_size=5, activation='relu', pad_mode="valid")
        self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
        self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
@@ -67,20 +67,19 @@ def test_qat_lenet():
    img = Tensor(np.ones((32, 1, 32, 32)).astype(np.float32))
    net = LeNet5()
    net = qat.convert_quant_network(
        net, freeze_bn=10000, num_bits=8)
        net, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
    # should load the checkpoint. mock here
    for param in net.get_parameters():
        param.init_data()
    qat.export_geir(net, img, file_name="quant.pb")
    qat.export(net, img, file_name="quant.pb")


@pytest.mark.skip(reason="no `te.lang.cce` in ut env")
def test_qat_mobile():
    net = MobileNetV2()
    network = mobilenetV2(num_classes=1000)
    img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32))
    net = qat.convert_quant_network(
        net, quant_delay=0, bn_fold=True, freeze_bn=10000, num_bits=8)
    network = qat.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
    # should load the checkpoint. mock here
    for param in net.get_parameters():
    for param in network.get_parameters():
        param.init_data()
    qat.export_geir(net, img, file_name="quant.pb")
    qat.export(network, img, file_name="quant.pb")
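Taken together, the updated tests exercise the API surface touched by this fix: convert_quant_network is called with bn_fold/per_channel/symmetric and the exporter is qat.export instead of export_geir. A sketch mirroring the call sequence in the tests above, with checkpoint loading still mocked via init_data:

import numpy as np
from mindspore import Tensor
from mindspore.train.quant import quant as qat
from model_zoo.mobilenetv2_quant.src.mobilenetV2 import mobilenetV2

network = mobilenetV2(num_classes=1000)
network = qat.convert_quant_network(network, bn_fold=True,
                                    per_channel=[True, False],
                                    symmetric=[True, False])
for param in network.get_parameters():   # stand-in for loading a real checkpoint
    param.init_data()
img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32))
qat.export(network, img, file_name="quant.pb")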