forked from OSSInnovation/mindspore
add QuantDtype and Observer
This commit is contained in:
parent
1e678a84dc
commit
025ea2f392
|
@ -248,6 +248,7 @@ install(
|
|||
${CMAKE_SOURCE_DIR}/mindspore/ops
|
||||
${CMAKE_SOURCE_DIR}/mindspore/communication
|
||||
${CMAKE_SOURCE_DIR}/mindspore/profiler
|
||||
${CMAKE_SOURCE_DIR}/mindspore/compression
|
||||
DESTINATION ${INSTALL_PY_DIR}
|
||||
COMPONENT mindspore
|
||||
)
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
MindSpore compression module.
|
||||
"""
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Compression common module.
|
||||
"""
|
||||
|
||||
from .constant import *
|
|
@ -0,0 +1,85 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Constant module for compression"""
|
||||
import enum
|
||||
import re
|
||||
from types import DynamicClassAttribute
|
||||
|
||||
|
||||
__all__ = ["QuantDtype"]
|
||||
|
||||
|
||||
@enum.unique
|
||||
class QuantDtype(enum.Enum):
|
||||
"""
|
||||
For type switch
|
||||
"""
|
||||
INT2 = "INT2"
|
||||
INT3 = "INT3"
|
||||
INT4 = "INT4"
|
||||
INT5 = "INT5"
|
||||
INT6 = "INT6"
|
||||
INT7 = "INT7"
|
||||
INT8 = "INT8"
|
||||
|
||||
UINT2 = "UINT2"
|
||||
UINT3 = "UINT3"
|
||||
UINT4 = "UINT4"
|
||||
UINT5 = "UINT5"
|
||||
UINT6 = "UINT6"
|
||||
UINT7 = "UINT7"
|
||||
UINT8 = "UINT8"
|
||||
|
||||
FLOAT16 = "FLOAT16"
|
||||
FLOAT32 = "FLOAT32"
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}"
|
||||
|
||||
@staticmethod
|
||||
def is_signed(dtype):
|
||||
return dtype in [QuantDtype.INT2, QuantDtype.INT3, QuantDtype.INT4, QuantDtype.INT5,
|
||||
QuantDtype.INT6, QuantDtype.INT7, QuantDtype.INT8]
|
||||
|
||||
@staticmethod
|
||||
def switch_signed(dtype):
|
||||
"""switch signed"""
|
||||
type_map = {
|
||||
QuantDtype.INT2: QuantDtype.UINT2,
|
||||
QuantDtype.INT3: QuantDtype.UINT3,
|
||||
QuantDtype.INT4: QuantDtype.UINT4,
|
||||
QuantDtype.INT5: QuantDtype.UINT5,
|
||||
QuantDtype.INT6: QuantDtype.UINT6,
|
||||
QuantDtype.INT7: QuantDtype.UINT7,
|
||||
QuantDtype.INT8: QuantDtype.UINT8,
|
||||
QuantDtype.UINT2: QuantDtype.INT2,
|
||||
QuantDtype.UINT3: QuantDtype.INT3,
|
||||
QuantDtype.UINT4: QuantDtype.INT4,
|
||||
QuantDtype.UINT5: QuantDtype.INT5,
|
||||
QuantDtype.UINT6: QuantDtype.INT6,
|
||||
QuantDtype.UINT7: QuantDtype.INT7,
|
||||
QuantDtype.UINT8: QuantDtype.INT8
|
||||
}
|
||||
return type_map[dtype]
|
||||
|
||||
@DynamicClassAttribute
|
||||
def value(self):
|
||||
"""The value of the Enum member."""
|
||||
return int(re.search(r"(\d+)", self._value_).group(1))
|
||||
|
||||
@DynamicClassAttribute
|
||||
def num_bits(self):
|
||||
"""The num_bits of the Enum member."""
|
||||
return self.value
|
|
@ -24,6 +24,7 @@ from mindspore.common.parameter import Parameter
|
|||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore._checkparam import Validator, Rel, twice
|
||||
from mindspore.compression.common import QuantDtype
|
||||
import mindspore.context as context
|
||||
from .normalization import BatchNorm2d, BatchNorm1d
|
||||
from .activation import get_activation, ReLU, LeakyReLU
|
||||
|
@ -277,13 +278,233 @@ class BatchNormFoldCell(Cell):
|
|||
return batch_mean, batch_std, running_mean, running_std
|
||||
|
||||
|
||||
def _partial_init(cls_or_self, **kwargs):
|
||||
"""
|
||||
Wrapper that allows creation of class factories.
|
||||
|
||||
This can be useful when there is a need to create classes with the same
|
||||
constructor arguments, but different instances.
|
||||
|
||||
Example::
|
||||
>>> Foo.partial_init = classmethod(_partial_init)
|
||||
>>> foo_builder = Foo.partial_init(a=3, b=4).partial_init(answer=42)
|
||||
>>> foo_instance1 = foo_builder()
|
||||
>>> foo_instance2 = foo_builder()
|
||||
>>> id(foo_instance1) == id(foo_instance2)
|
||||
False
|
||||
"""
|
||||
|
||||
class _PartialWrapper:
|
||||
r"""
|
||||
class of wrapper that allows creation of class factories.
|
||||
"""
|
||||
|
||||
def __init__(self, p):
|
||||
self.p = p
|
||||
|
||||
def __call__(self, *args, **keywords):
|
||||
return self.p(*args, **keywords)
|
||||
|
||||
def __repr__(self):
|
||||
return self.p.__repr__()
|
||||
|
||||
partial_init = _partial_init
|
||||
|
||||
r = _PartialWrapper(partial(cls_or_self, **kwargs))
|
||||
return r
|
||||
|
||||
|
||||
class Observer(Cell):
|
||||
"""
|
||||
Base class of Observer. Observer is used to calculate the statistics of specific layer.
|
||||
|
||||
Notes:
|
||||
This class is an abstract class.
|
||||
|
||||
Args:
|
||||
quant_dtype (QuantDtype): The type of FakeQuant data.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_dtype):
|
||||
super(Observer, self).__init__()
|
||||
self.quant_dtype = quant_dtype
|
||||
|
||||
def extend_repr(self):
|
||||
s = f"dtype={self.dtype}"
|
||||
return s
|
||||
|
||||
def construct(self):
|
||||
pass
|
||||
|
||||
partial_init = classmethod(_partial_init)
|
||||
|
||||
|
||||
class UniformQuantObserver(Observer):
|
||||
"""
|
||||
The base class of Uniform Quantization Observer.
|
||||
|
||||
Args:
|
||||
quant_dtype (QuantDtype): The type of FakeQuant data. Default: QuantDtype.INT8.
|
||||
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
|
||||
symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
|
||||
narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
|
||||
num_channels (int): declarate the min and max channel size, Default: 1.
|
||||
|
||||
Returns:
|
||||
Tensor.
|
||||
"""
|
||||
|
||||
min_max_map = {
|
||||
QuantDtype.INT2: (-2, 1),
|
||||
QuantDtype.INT3: (-4, 3),
|
||||
QuantDtype.INT4: (-8, 7),
|
||||
QuantDtype.INT5: (-16, 15),
|
||||
QuantDtype.INT6: (-32, 31),
|
||||
QuantDtype.INT7: (-64, 63),
|
||||
QuantDtype.INT8: (-128, 127),
|
||||
|
||||
QuantDtype.UINT2: (0, 3),
|
||||
QuantDtype.UINT3: (0, 7),
|
||||
QuantDtype.UINT4: (0, 15),
|
||||
QuantDtype.UINT5: (0, 31),
|
||||
QuantDtype.UINT6: (0, 63),
|
||||
QuantDtype.UINT7: (0, 127),
|
||||
QuantDtype.UINT8: (0, 255)
|
||||
}
|
||||
|
||||
def __init__(self, quant_dtype=QuantDtype.INT8, per_channel=False, symmetric=False, narrow_range=False,
|
||||
num_channels=1):
|
||||
super(UniformQuantObserver, self).__init__(quant_dtype)
|
||||
self.per_channel = per_channel
|
||||
self.symmetric = symmetric
|
||||
self.narrow_range = narrow_range
|
||||
self.num_channels = num_channels
|
||||
|
||||
|
||||
class FakeQuantWithMinMaxObserver(UniformQuantObserver):
|
||||
r"""
|
||||
Quantization aware op. This OP provides the fake quantization observer function on data with min and max.
|
||||
|
||||
Args:
|
||||
min_init (int, float): The initialized min value. Default: -6.
|
||||
max_init (int, float): The initialized max value. Default: 6.
|
||||
ema (bool): The exponential Moving Average algorithm updates min and max. Default: False.
|
||||
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
|
||||
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
|
||||
channel_axis (int): Quantization by channel axis. Default: 1.
|
||||
num_channels (int): declarate the min and max channel size, Default: 1.
|
||||
quant_dtype (QuantDtype): The datatype of quantization, supporting 4 and 8bits. Default: QuantDtype.INT8.
|
||||
symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
|
||||
narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
|
||||
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - The input of FakeQuantWithMinMaxObserver.
|
||||
|
||||
Outputs:
|
||||
Tensor, with the same type and shape as the `x`.
|
||||
|
||||
Examples:
|
||||
>>> fake_quant = FakeQuantWithMinMaxObserver()
|
||||
>>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
|
||||
>>> result = fake_quant(input_x)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
min_init=-6,
|
||||
max_init=6,
|
||||
ema=False,
|
||||
ema_decay=0.999,
|
||||
per_channel=False,
|
||||
channel_axis=1,
|
||||
num_channels=1,
|
||||
quant_dtype=QuantDtype.INT8,
|
||||
symmetric=False,
|
||||
narrow_range=False,
|
||||
quant_delay=0):
|
||||
"""Initialize FakeQuantWithMinMax layer"""
|
||||
super(FakeQuantWithMinMaxObserver, self).__init__(quant_dtype=quant_dtype, per_channel=per_channel,
|
||||
symmetric=symmetric, narrow_range=narrow_range,
|
||||
num_channels=num_channels)
|
||||
Validator.check_type("min_init", min_init, [int, float])
|
||||
Validator.check_type("max_init", max_init, [int, float])
|
||||
Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT)
|
||||
Validator.check_integer('quant_delay', quant_delay, 0, Rel.GE)
|
||||
self.min_init = min_init
|
||||
self.max_init = max_init
|
||||
self.quant_dtype = quant_dtype
|
||||
self.ema = ema
|
||||
self.ema_decay = ema_decay
|
||||
self.per_channel = per_channel
|
||||
self.num_channels = num_channels
|
||||
self.channel_axis = channel_axis
|
||||
self.quant_delay = quant_delay
|
||||
self.symmetric = symmetric
|
||||
self.narrow_range = narrow_range
|
||||
self.is_ascend = context.get_context('device_target') == "Ascend"
|
||||
|
||||
# init tensor min and max for fake quant op
|
||||
if self.per_channel:
|
||||
min_array = np.array([self.min_init] * self.num_channels).astype(np.float32)
|
||||
max_array = np.array([self.max_init] * self.num_channels).astype(np.float32)
|
||||
else:
|
||||
min_array = np.array([self.min_init]).astype(np.float32)
|
||||
max_array = np.array([self.max_init]).astype(np.float32)
|
||||
self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
|
||||
self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
|
||||
|
||||
# init fake quant relative op
|
||||
if self.per_channel:
|
||||
quant_fun = partial(Q.FakeQuantPerChannel, channel_axis=self.channel_axis)
|
||||
ema_fun = partial(Q.MinMaxUpdatePerChannel, channel_axis=self.channel_axis)
|
||||
else:
|
||||
quant_fun = Q.FakeQuantPerLayer
|
||||
ema_fun = Q.MinMaxUpdatePerLayer
|
||||
|
||||
self.ema_update = ema_fun(ema=self.ema, ema_decay=self.ema_decay)
|
||||
if self.is_ascend:
|
||||
self.fake_quant_train = quant_fun(num_bits=self.quant_dtype.num_bits,
|
||||
symmetric=self.symmetric,
|
||||
narrow_range=self.narrow_range,
|
||||
quant_delay=self.quant_delay)
|
||||
self.fake_quant_infer = self.fake_quant_train
|
||||
else:
|
||||
quant_fun = partial(quant_fun,
|
||||
ema=self.ema,
|
||||
ema_decay=ema_decay,
|
||||
num_bits=self.quant_dtype.num_bits,
|
||||
symmetric=self.symmetric,
|
||||
narrow_range=self.narrow_range,
|
||||
quant_delay=self.quant_delay)
|
||||
self.fake_quant_train = quant_fun(training=True)
|
||||
self.fake_quant_infer = quant_fun(training=False)
|
||||
|
||||
def extend_repr(self):
|
||||
s = 'quant_dtype={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}({}, {}), ' \
|
||||
'quant_delay={}, min_init={}, max_init={}'.format(self.quant_dtype, self.symmetric, self.narrow_range,
|
||||
self.ema, self.ema_decay, self.per_channel,
|
||||
self.channel_axis, self.num_channels, self.quant_delay,
|
||||
self.min_init, self.max_init)
|
||||
return s
|
||||
|
||||
def construct(self, x):
|
||||
if self.training:
|
||||
min_up, max_up = self.ema_update(x, self.minq, self.maxq)
|
||||
P.Assign()(self.minq, min_up)
|
||||
P.Assign()(self.maxq, max_up)
|
||||
out = self.fake_quant_train(x, self.minq, self.maxq)
|
||||
else:
|
||||
out = self.fake_quant_infer(x, self.minq, self.maxq)
|
||||
return out
|
||||
|
||||
|
||||
class FakeQuantWithMinMax(Cell):
|
||||
r"""
|
||||
Quantization aware op. This OP provides the fake quantization observer function on data with min and max.
|
||||
|
||||
Args:
|
||||
min_init (int, float): The dimension of channel or 1(layer). Default: -6.
|
||||
max_init (int, float): The dimension of channel or 1(layer). Default: 6.
|
||||
min_init (int, float): The initialized min value. Default: -6.
|
||||
max_init (int, float): The initialized max value. Default: 6.
|
||||
ema (bool): The exponential Moving Average algorithm updates min and max. Default: False.
|
||||
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
|
||||
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
|
||||
|
|
Loading…
Reference in New Issue