add QuantDtype and Observer

This commit is contained in:
yuchaojie 2020-10-15 21:36:23 +08:00
parent 1e678a84dc
commit 025ea2f392
5 changed files with 345 additions and 2 deletions

View File

@ -248,6 +248,7 @@ install(
${CMAKE_SOURCE_DIR}/mindspore/ops
${CMAKE_SOURCE_DIR}/mindspore/communication
${CMAKE_SOURCE_DIR}/mindspore/profiler
${CMAKE_SOURCE_DIR}/mindspore/compression
DESTINATION ${INSTALL_PY_DIR}
COMPONENT mindspore
)

View File

@ -0,0 +1,17 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
MindSpore compression module.
"""

View File

@ -0,0 +1,19 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Compression common module.
"""
from .constant import *

View File

@ -0,0 +1,85 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Constant module for compression"""
import enum
import re
from types import DynamicClassAttribute
__all__ = ["QuantDtype"]
@enum.unique
class QuantDtype(enum.Enum):
"""
For type switch
"""
INT2 = "INT2"
INT3 = "INT3"
INT4 = "INT4"
INT5 = "INT5"
INT6 = "INT6"
INT7 = "INT7"
INT8 = "INT8"
UINT2 = "UINT2"
UINT3 = "UINT3"
UINT4 = "UINT4"
UINT5 = "UINT5"
UINT6 = "UINT6"
UINT7 = "UINT7"
UINT8 = "UINT8"
FLOAT16 = "FLOAT16"
FLOAT32 = "FLOAT32"
def __str__(self):
return f"{self.name}"
@staticmethod
def is_signed(dtype):
return dtype in [QuantDtype.INT2, QuantDtype.INT3, QuantDtype.INT4, QuantDtype.INT5,
QuantDtype.INT6, QuantDtype.INT7, QuantDtype.INT8]
@staticmethod
def switch_signed(dtype):
"""switch signed"""
type_map = {
QuantDtype.INT2: QuantDtype.UINT2,
QuantDtype.INT3: QuantDtype.UINT3,
QuantDtype.INT4: QuantDtype.UINT4,
QuantDtype.INT5: QuantDtype.UINT5,
QuantDtype.INT6: QuantDtype.UINT6,
QuantDtype.INT7: QuantDtype.UINT7,
QuantDtype.INT8: QuantDtype.UINT8,
QuantDtype.UINT2: QuantDtype.INT2,
QuantDtype.UINT3: QuantDtype.INT3,
QuantDtype.UINT4: QuantDtype.INT4,
QuantDtype.UINT5: QuantDtype.INT5,
QuantDtype.UINT6: QuantDtype.INT6,
QuantDtype.UINT7: QuantDtype.INT7,
QuantDtype.UINT8: QuantDtype.INT8
}
return type_map[dtype]
@DynamicClassAttribute
def value(self):
"""The value of the Enum member."""
return int(re.search(r"(\d+)", self._value_).group(1))
@DynamicClassAttribute
def num_bits(self):
"""The num_bits of the Enum member."""
return self.value

View File

@ -24,6 +24,7 @@ from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator, Rel, twice
from mindspore.compression.common import QuantDtype
import mindspore.context as context
from .normalization import BatchNorm2d, BatchNorm1d
from .activation import get_activation, ReLU, LeakyReLU
@ -277,13 +278,233 @@ class BatchNormFoldCell(Cell):
return batch_mean, batch_std, running_mean, running_std
def _partial_init(cls_or_self, **kwargs):
"""
Wrapper that allows creation of class factories.
This can be useful when there is a need to create classes with the same
constructor arguments, but different instances.
Example::
>>> Foo.partial_init = classmethod(_partial_init)
>>> foo_builder = Foo.partial_init(a=3, b=4).partial_init(answer=42)
>>> foo_instance1 = foo_builder()
>>> foo_instance2 = foo_builder()
>>> id(foo_instance1) == id(foo_instance2)
False
"""
class _PartialWrapper:
r"""
class of wrapper that allows creation of class factories.
"""
def __init__(self, p):
self.p = p
def __call__(self, *args, **keywords):
return self.p(*args, **keywords)
def __repr__(self):
return self.p.__repr__()
partial_init = _partial_init
r = _PartialWrapper(partial(cls_or_self, **kwargs))
return r
class Observer(Cell):
"""
Base class of Observer. Observer is used to calculate the statistics of specific layer.
Notes:
This class is an abstract class.
Args:
quant_dtype (QuantDtype): The type of FakeQuant data.
"""
def __init__(self, quant_dtype):
super(Observer, self).__init__()
self.quant_dtype = quant_dtype
def extend_repr(self):
s = f"dtype={self.dtype}"
return s
def construct(self):
pass
partial_init = classmethod(_partial_init)
class UniformQuantObserver(Observer):
"""
The base class of Uniform Quantization Observer.
Args:
quant_dtype (QuantDtype): The type of FakeQuant data. Default: QuantDtype.INT8.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
num_channels (int): declarate the min and max channel size, Default: 1.
Returns:
Tensor.
"""
min_max_map = {
QuantDtype.INT2: (-2, 1),
QuantDtype.INT3: (-4, 3),
QuantDtype.INT4: (-8, 7),
QuantDtype.INT5: (-16, 15),
QuantDtype.INT6: (-32, 31),
QuantDtype.INT7: (-64, 63),
QuantDtype.INT8: (-128, 127),
QuantDtype.UINT2: (0, 3),
QuantDtype.UINT3: (0, 7),
QuantDtype.UINT4: (0, 15),
QuantDtype.UINT5: (0, 31),
QuantDtype.UINT6: (0, 63),
QuantDtype.UINT7: (0, 127),
QuantDtype.UINT8: (0, 255)
}
def __init__(self, quant_dtype=QuantDtype.INT8, per_channel=False, symmetric=False, narrow_range=False,
num_channels=1):
super(UniformQuantObserver, self).__init__(quant_dtype)
self.per_channel = per_channel
self.symmetric = symmetric
self.narrow_range = narrow_range
self.num_channels = num_channels
class FakeQuantWithMinMaxObserver(UniformQuantObserver):
r"""
Quantization aware op. This OP provides the fake quantization observer function on data with min and max.
Args:
min_init (int, float): The initialized min value. Default: -6.
max_init (int, float): The initialized max value. Default: 6.
ema (bool): The exponential Moving Average algorithm updates min and max. Default: False.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
channel_axis (int): Quantization by channel axis. Default: 1.
num_channels (int): declarate the min and max channel size, Default: 1.
quant_dtype (QuantDtype): The datatype of quantization, supporting 4 and 8bits. Default: QuantDtype.INT8.
symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
Inputs:
- **x** (Tensor) - The input of FakeQuantWithMinMaxObserver.
Outputs:
Tensor, with the same type and shape as the `x`.
Examples:
>>> fake_quant = FakeQuantWithMinMaxObserver()
>>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
>>> result = fake_quant(input_x)
"""
def __init__(self,
min_init=-6,
max_init=6,
ema=False,
ema_decay=0.999,
per_channel=False,
channel_axis=1,
num_channels=1,
quant_dtype=QuantDtype.INT8,
symmetric=False,
narrow_range=False,
quant_delay=0):
"""Initialize FakeQuantWithMinMax layer"""
super(FakeQuantWithMinMaxObserver, self).__init__(quant_dtype=quant_dtype, per_channel=per_channel,
symmetric=symmetric, narrow_range=narrow_range,
num_channels=num_channels)
Validator.check_type("min_init", min_init, [int, float])
Validator.check_type("max_init", max_init, [int, float])
Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT)
Validator.check_integer('quant_delay', quant_delay, 0, Rel.GE)
self.min_init = min_init
self.max_init = max_init
self.quant_dtype = quant_dtype
self.ema = ema
self.ema_decay = ema_decay
self.per_channel = per_channel
self.num_channels = num_channels
self.channel_axis = channel_axis
self.quant_delay = quant_delay
self.symmetric = symmetric
self.narrow_range = narrow_range
self.is_ascend = context.get_context('device_target') == "Ascend"
# init tensor min and max for fake quant op
if self.per_channel:
min_array = np.array([self.min_init] * self.num_channels).astype(np.float32)
max_array = np.array([self.max_init] * self.num_channels).astype(np.float32)
else:
min_array = np.array([self.min_init]).astype(np.float32)
max_array = np.array([self.max_init]).astype(np.float32)
self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
# init fake quant relative op
if self.per_channel:
quant_fun = partial(Q.FakeQuantPerChannel, channel_axis=self.channel_axis)
ema_fun = partial(Q.MinMaxUpdatePerChannel, channel_axis=self.channel_axis)
else:
quant_fun = Q.FakeQuantPerLayer
ema_fun = Q.MinMaxUpdatePerLayer
self.ema_update = ema_fun(ema=self.ema, ema_decay=self.ema_decay)
if self.is_ascend:
self.fake_quant_train = quant_fun(num_bits=self.quant_dtype.num_bits,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
quant_delay=self.quant_delay)
self.fake_quant_infer = self.fake_quant_train
else:
quant_fun = partial(quant_fun,
ema=self.ema,
ema_decay=ema_decay,
num_bits=self.quant_dtype.num_bits,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
quant_delay=self.quant_delay)
self.fake_quant_train = quant_fun(training=True)
self.fake_quant_infer = quant_fun(training=False)
def extend_repr(self):
s = 'quant_dtype={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}({}, {}), ' \
'quant_delay={}, min_init={}, max_init={}'.format(self.quant_dtype, self.symmetric, self.narrow_range,
self.ema, self.ema_decay, self.per_channel,
self.channel_axis, self.num_channels, self.quant_delay,
self.min_init, self.max_init)
return s
def construct(self, x):
if self.training:
min_up, max_up = self.ema_update(x, self.minq, self.maxq)
P.Assign()(self.minq, min_up)
P.Assign()(self.maxq, max_up)
out = self.fake_quant_train(x, self.minq, self.maxq)
else:
out = self.fake_quant_infer(x, self.minq, self.maxq)
return out
class FakeQuantWithMinMax(Cell):
r"""
Quantization aware op. This OP provides the fake quantization observer function on data with min and max.
Args:
min_init (int, float): The dimension of channel or 1(layer). Default: -6.
max_init (int, float): The dimension of channel or 1(layer). Default: 6.
min_init (int, float): The initialized min value. Default: -6.
max_init (int, float): The initialized max value. Default: 6.
ema (bool): The exponential Moving Average algorithm updates min and max. Default: False.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.