forked from mindspore-Ecosystem/mindspore
!1504 add custom tbe ops for quant aware training
Merge pull request !1504 from wandongdong/master
This commit is contained in:
commit
c8b30f9290
|
@ -22,11 +22,15 @@ from mindspore.common.parameter import Parameter
|
||||||
from mindspore.common.initializer import initializer
|
from mindspore.common.initializer import initializer
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
from mindspore._checkparam import check_int_positive, check_bool, twice
|
from mindspore._checkparam import check_int_positive, check_bool, twice
|
||||||
|
from mindspore._checkparam import Validator as validator
|
||||||
from mindspore.nn.cell import Cell
|
from mindspore.nn.cell import Cell
|
||||||
from mindspore.nn.layer.activation import get_activation
|
from mindspore.nn.layer.activation import get_activation
|
||||||
|
import mindspore.context as context
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'FakeQuantWithMinMax',
|
'FakeQuantWithMinMax',
|
||||||
|
'DepthwiseConv2dBatchNormQuant',
|
||||||
'Conv2dBatchNormQuant',
|
'Conv2dBatchNormQuant',
|
||||||
'Conv2dQuant',
|
'Conv2dQuant',
|
||||||
'DenseQuant',
|
'DenseQuant',
|
||||||
|
@ -39,6 +43,169 @@ __all__ = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class BatchNormFoldCell(Cell):
|
||||||
|
"""
|
||||||
|
Batch normalization folded.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
momentum (float): Momentum value should be [0, 1]. Default: 0.1.
|
||||||
|
epsilon (float): A small float number to avoid dividing by 0. 1e-5 if dtype in
|
||||||
|
float32 else 1e-3. Default: 1e-5.
|
||||||
|
freeze_bn (int): Delay in steps at which computation switches from regular batch
|
||||||
|
norm to frozen mean and std. Default: 0.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **x** (Tensor) - Tensor of shape :math:`(N, C, H, W)`.
|
||||||
|
- **mean** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||||
|
- **variance** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||||
|
- **global_step** (Tensor) - Tensor to record current global step.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tuple of 4 Tensor, the normalized input and the updated parameters.
|
||||||
|
|
||||||
|
- **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||||
|
- **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||||
|
- **running_mean** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||||
|
- **running_std** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, momentum=0.9, epsilon=1e-5, freeze_bn=0):
|
||||||
|
"""init batch norm fold layer"""
|
||||||
|
super(BatchNormFoldCell, self).__init__()
|
||||||
|
self.epsilon = epsilon
|
||||||
|
self.is_gpu = context.get_context('device_target') == "GPU"
|
||||||
|
if self.is_gpu:
|
||||||
|
self.bn_train = P.BatchNormFold(momentum, epsilon, is_training=True, freeze_bn=freeze_bn)
|
||||||
|
self.bn_infer = P.BatchNormFold(momentum, epsilon, is_training=False, freeze_bn=freeze_bn)
|
||||||
|
else:
|
||||||
|
self.bn_reduce = P.BNTrainingReduce()
|
||||||
|
self.bn_update = P.BatchNormFoldD(momentum, epsilon, is_training=True, freeze_bn=freeze_bn)
|
||||||
|
|
||||||
|
def construct(self, x, mean, variance, global_step):
|
||||||
|
if self.is_gpu:
|
||||||
|
if self.training:
|
||||||
|
batch_mean, batch_std, running_mean, running_std = self.bn_train(x, mean, variance, global_step)
|
||||||
|
else:
|
||||||
|
batch_mean, batch_std, running_mean, running_std = self.bn_infer(x, mean, variance, global_step)
|
||||||
|
else:
|
||||||
|
if self.training:
|
||||||
|
x_sum, x_square_sum = self.bn_reduce(x)
|
||||||
|
_, batch_mean, batch_std, running_mean, running_std, mean_updated, variance_updated = \
|
||||||
|
self.bn_update(x, x_sum, x_square_sum, mean, variance)
|
||||||
|
P.Assign()(mean, mean_updated)
|
||||||
|
P.Assign()(variance, variance_updated)
|
||||||
|
else:
|
||||||
|
batch_mean = P.ZerosLike()(variance)
|
||||||
|
batch_std = P.OnesLike()(variance)
|
||||||
|
running_mean = P.TensorAdd()(mean, 0.)
|
||||||
|
running_std = P.Sqrt()(P.TensorAdd()(variance, self.epsilon))
|
||||||
|
return batch_mean, batch_std, running_mean, running_std
|
||||||
|
|
||||||
|
|
||||||
|
class FakeQuantWithMinMaxD(Cell):
|
||||||
|
r"""
|
||||||
|
Aware Quantization training op of ascend. This OP provide Fake quantization observer
|
||||||
|
function on data with min and max.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
min_init (int, list): The dimension of channel or 1(layer). Default: -6.
|
||||||
|
max_init (int, list): The dimension of channel or 1(layer). Default: 6.
|
||||||
|
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
|
||||||
|
ema (bool): Exponential Moving Average algorithm update min and max. Default: False.
|
||||||
|
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.9999.
|
||||||
|
per_channel (bool): Quantization by layer or channel. Default: False.
|
||||||
|
out_channels (int): declarate the min and max channel size, Default: 1.
|
||||||
|
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
|
||||||
|
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
|
||||||
|
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **x** (Tensor) - The input of FakeQuantWithMinMax.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, with the same type and shape as the `x`.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> fake_quant = nn.FakeQuantWithMinMaxD()
|
||||||
|
>>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
|
||||||
|
>>> result = fake_quant(input_x)
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
min_init=-6,
|
||||||
|
max_init=6,
|
||||||
|
num_bits=8,
|
||||||
|
ema=False,
|
||||||
|
ema_decay=0.999,
|
||||||
|
per_channel=False,
|
||||||
|
channel_size=1,
|
||||||
|
quant_delay=0,
|
||||||
|
symmetric=False,
|
||||||
|
narrow_range=False,
|
||||||
|
training=True):
|
||||||
|
"""init FakeQuantWithMinMax ascend layer"""
|
||||||
|
super(FakeQuantWithMinMaxD, self).__init__()
|
||||||
|
|
||||||
|
self.min_init = min_init
|
||||||
|
self.num_bits = num_bits
|
||||||
|
self.max_init = max_init
|
||||||
|
self.ema = ema
|
||||||
|
self.ema_decay = ema_decay
|
||||||
|
self.per_channel = per_channel
|
||||||
|
self.channel_size = channel_size
|
||||||
|
self.quant_delay = quant_delay
|
||||||
|
self.symmetric = symmetric
|
||||||
|
self.narrow_range = narrow_range
|
||||||
|
self.training = training
|
||||||
|
|
||||||
|
if not per_channel:
|
||||||
|
self.fake_quant = P.FakeQuantWithMinMax(num_bits=self.num_bits,
|
||||||
|
ema=self.ema,
|
||||||
|
ema_decay=self.ema_decay,
|
||||||
|
quant_delay=self.quant_delay,
|
||||||
|
symmetric=self.symmetric,
|
||||||
|
narrow_range=self.narrow_range,
|
||||||
|
training=training)
|
||||||
|
self.ema_update = P.FakeQuantWithMinMaxUpdate(num_bits=self.num_bits,
|
||||||
|
ema=self.ema,
|
||||||
|
ema_decay=self.ema_decay,
|
||||||
|
quant_delay=self.quant_delay,
|
||||||
|
symmetric=self.symmetric,
|
||||||
|
narrow_range=self.narrow_range,
|
||||||
|
training=training)
|
||||||
|
else:
|
||||||
|
raise RuntimeError("not support per channel")
|
||||||
|
|
||||||
|
if isinstance(min_init, Parameter):
|
||||||
|
self.minq = min_init
|
||||||
|
self.maxq = max_init
|
||||||
|
else:
|
||||||
|
self.minq = Parameter(Tensor(np.array([min_init]).astype(np.float32)),
|
||||||
|
name='quant_min',
|
||||||
|
requires_grad=False)
|
||||||
|
self.maxq = Parameter(Tensor(np.array([max_init]).astype(np.float32)),
|
||||||
|
name='quant_max',
|
||||||
|
requires_grad=False)
|
||||||
|
self.reduce_min = P.ReduceMin()
|
||||||
|
self.reduce_max = P.ReduceMax()
|
||||||
|
|
||||||
|
def extend_repr(self):
|
||||||
|
s = 'min_init={}, max_init={}, ema={}, ema_decay={}, per_channel={}, channel_size={}, quant_delay={}'.format(
|
||||||
|
self.min_init, self.max_init, self.ema, self.ema_decay, self.per_channel, self.channel_size,
|
||||||
|
self.quant_delay)
|
||||||
|
return s
|
||||||
|
|
||||||
|
def construct(self, x, minq, maxq):
|
||||||
|
if self.training:
|
||||||
|
min_up, max_up = self.ema_update(x, minq, maxq)
|
||||||
|
out = self.fake_quant(x, min_up, max_up)
|
||||||
|
P.Assign()(self.minq, min_up)
|
||||||
|
P.Assign()(self.maxq, max_up)
|
||||||
|
else:
|
||||||
|
out = self.fake_quant(x, minq, maxq)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
class FakeQuantWithMinMax(Cell):
|
class FakeQuantWithMinMax(Cell):
|
||||||
r"""
|
r"""
|
||||||
Aware Quantization training op. This OP provide Fake quantization observer function on data with min and max.
|
Aware Quantization training op. This OP provide Fake quantization observer function on data with min and max.
|
||||||
|
@ -62,7 +229,7 @@ class FakeQuantWithMinMax(Cell):
|
||||||
Tensor, with the same type and shape as the `x`.
|
Tensor, with the same type and shape as the `x`.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> fake_quant = nn.FakeQuantWithMinMax()
|
>>> fake_quant = FakeQuantWithMinMax()
|
||||||
>>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
|
>>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
|
||||||
>>> result = fake_quant(input_x)
|
>>> result = fake_quant(input_x)
|
||||||
"""
|
"""
|
||||||
|
@ -77,7 +244,9 @@ class FakeQuantWithMinMax(Cell):
|
||||||
out_channels=1,
|
out_channels=1,
|
||||||
quant_delay=0,
|
quant_delay=0,
|
||||||
symmetric=False,
|
symmetric=False,
|
||||||
narrow_range=False):
|
narrow_range=False,
|
||||||
|
training=True):
|
||||||
|
"""init FakeQuantWithMinMax layer"""
|
||||||
super(FakeQuantWithMinMax, self).__init__()
|
super(FakeQuantWithMinMax, self).__init__()
|
||||||
|
|
||||||
self.min_init = min_init
|
self.min_init = min_init
|
||||||
|
@ -90,12 +259,13 @@ class FakeQuantWithMinMax(Cell):
|
||||||
self.quant_delay = quant_delay
|
self.quant_delay = quant_delay
|
||||||
self.symmetric = symmetric
|
self.symmetric = symmetric
|
||||||
self.narrow_range = narrow_range
|
self.narrow_range = narrow_range
|
||||||
|
self.training = training
|
||||||
|
|
||||||
if per_channel:
|
if per_channel:
|
||||||
min_array = np.array([self.min_init for i in range(
|
min_array = np.array([self.min_init for i in range(0, self.out_channels)]).astype(np.float32)
|
||||||
0, self.out_channels)]).astype(np.float32)
|
max_array = np.array([self.max_init for i in range(0, self.channel_size)]).astype(np.float32)
|
||||||
max_array = np.array([self.max_init for i in range(
|
self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
|
||||||
0, self.out_channels)]).astype(np.float32)
|
self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
|
||||||
self.fake_quant_train = P.FakeQuantWithMinMaxPerChannel(num_bits=self.num_bits,
|
self.fake_quant_train = P.FakeQuantWithMinMaxPerChannel(num_bits=self.num_bits,
|
||||||
ema=self.ema,
|
ema=self.ema,
|
||||||
ema_decay=self.ema_decay,
|
ema_decay=self.ema_decay,
|
||||||
|
@ -113,25 +283,44 @@ class FakeQuantWithMinMax(Cell):
|
||||||
else:
|
else:
|
||||||
min_array = np.array([min_init]).reshape(1).astype(np.float32)
|
min_array = np.array([min_init]).reshape(1).astype(np.float32)
|
||||||
max_array = np.array([max_init]).reshape(1).astype(np.float32)
|
max_array = np.array([max_init]).reshape(1).astype(np.float32)
|
||||||
self.fake_quant_train = P.FakeQuantWithMinMax(num_bits=self.num_bits,
|
self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
|
||||||
ema=self.ema,
|
self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
|
||||||
ema_decay=self.ema_decay,
|
if context.get_context('device_target') == "Ascend":
|
||||||
quant_delay=self.quant_delay,
|
self.fake_quant_train = FakeQuantWithMinMaxD(num_bits=self.num_bits,
|
||||||
symmetric=self.symmetric,
|
ema=self.ema,
|
||||||
narrow_range=self.narrow_range,
|
ema_decay=self.ema_decay,
|
||||||
training=True)
|
quant_delay=self.quant_delay,
|
||||||
self.fake_quant_infer = P.FakeQuantWithMinMax(num_bits=self.num_bits,
|
symmetric=self.symmetric,
|
||||||
ema=self.ema,
|
narrow_range=self.narrow_range,
|
||||||
ema_decay=self.ema_decay,
|
training=True,
|
||||||
quant_delay=self.quant_delay,
|
min_init=self.minq,
|
||||||
symmetric=self.symmetric,
|
max_init=self.maxq)
|
||||||
narrow_range=self.narrow_range,
|
self.fake_quant_infer = FakeQuantWithMinMaxD(num_bits=self.num_bits,
|
||||||
training=False)
|
ema=self.ema,
|
||||||
|
ema_decay=self.ema_decay,
|
||||||
self.minq = Parameter(
|
quant_delay=self.quant_delay,
|
||||||
Tensor(min_array), name='quant_min', requires_grad=False)
|
symmetric=self.symmetric,
|
||||||
self.maxq = Parameter(
|
narrow_range=self.narrow_range,
|
||||||
Tensor(max_array), name='quant_max', requires_grad=False)
|
training=False,
|
||||||
|
min_init=self.minq,
|
||||||
|
max_init=self.maxq)
|
||||||
|
elif context.get_context('device_target') == "GPU":
|
||||||
|
self.fake_quant_train = P.FakeQuantWithMinMax(num_bits=self.num_bits,
|
||||||
|
ema=self.ema,
|
||||||
|
ema_decay=self.ema_decay,
|
||||||
|
quant_delay=self.quant_delay,
|
||||||
|
symmetric=self.symmetric,
|
||||||
|
narrow_range=self.narrow_range,
|
||||||
|
training=True)
|
||||||
|
self.fake_quant_infer = P.FakeQuantWithMinMax(num_bits=self.num_bits,
|
||||||
|
ema=self.ema,
|
||||||
|
ema_decay=ema_decay,
|
||||||
|
quant_delay=quant_delay,
|
||||||
|
symmetric=self.symmetric,
|
||||||
|
narrow_range=self.narrow_range,
|
||||||
|
training=False)
|
||||||
|
else:
|
||||||
|
raise ValueError("Not support platform.")
|
||||||
|
|
||||||
def extend_repr(self):
|
def extend_repr(self):
|
||||||
s = 'min={}, max={}, ema={}, ema_decay={}, per_channel={}, quant_delay={}'.format(
|
s = 'min={}, max={}, ema={}, ema_decay={}, per_channel={}, quant_delay={}'.format(
|
||||||
|
@ -146,6 +335,191 @@ class FakeQuantWithMinMax(Cell):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class DepthwiseConv2dBatchNormQuant(Cell):
|
||||||
|
r"""
|
||||||
|
2D depthwise convolution with BatchNormal op folded layer.
|
||||||
|
|
||||||
|
For a more Detailed overview of Conv2d op.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): The number of input channel :math:`C_{in}`.
|
||||||
|
out_channels (int): The number of output channel :math:`C_{out}`.
|
||||||
|
kernel_size (Union[int, tuple]): Specifies the height and width of the 2D convolution window.
|
||||||
|
stride (int): Specifies stride for all spatial dimensions with the same value.
|
||||||
|
pad_mode: (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
|
||||||
|
padding: (int): Implicit paddings on both sides of the input. Default: 0.
|
||||||
|
eps (int): Parameters for BatchNormal. Default: 1e-5.
|
||||||
|
momentum (int): Parameters for BatchNormal op. Default: 0.9.
|
||||||
|
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
|
||||||
|
convolution kernel. Default: 'None'.
|
||||||
|
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
|
||||||
|
beta vector. Default: 'None'.
|
||||||
|
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
|
||||||
|
gamma vector. Default: 'None'.
|
||||||
|
mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
|
||||||
|
mean vector. Default: 'None'.
|
||||||
|
var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
|
||||||
|
variance vector. Default: 'None'.
|
||||||
|
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
|
||||||
|
freeze_bn (int): Quantization freeze BatchNormal op according by global step. Default: 100000.
|
||||||
|
fake (bool): Conv2dBatchNormQuant Cell add FakeQuantWithMinMax op or not. Default: True.
|
||||||
|
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
|
||||||
|
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
|
||||||
|
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
|
||||||
|
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> quant = nn.DepthwiseConv2dBatchNormQuant(1, 6,
|
||||||
|
kernel_size= (2, 2),
|
||||||
|
stride=(1, 1),
|
||||||
|
pad_mode="valid",
|
||||||
|
>>> dilation=(1, 1))
|
||||||
|
>>> input_x = Tensor(np.random.randint(-2, 2, (2, 1, 1, 3)), mindspore.float32)
|
||||||
|
>>> result = quant(input_x)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
pad_mode='same',
|
||||||
|
padding=0,
|
||||||
|
dilation=1,
|
||||||
|
group=1,
|
||||||
|
eps=1e-5,
|
||||||
|
momentum=0.997,
|
||||||
|
weight_init=None,
|
||||||
|
beta_init=None,
|
||||||
|
gamma_init=None,
|
||||||
|
mean_init=None,
|
||||||
|
var_init=None,
|
||||||
|
quant_delay=0,
|
||||||
|
freeze_bn=100000,
|
||||||
|
fake=True,
|
||||||
|
num_bits=8,
|
||||||
|
per_channel=False,
|
||||||
|
symmetric=False,
|
||||||
|
narrow_range=False):
|
||||||
|
"""init DepthwiseConv2dBatchNormQuant layer"""
|
||||||
|
super(DepthwiseConv2dBatchNormQuant, self).__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.pad_mode = pad_mode
|
||||||
|
self.padding = padding
|
||||||
|
self.dilation = twice(dilation)
|
||||||
|
self.stride = twice(stride)
|
||||||
|
self.group = group
|
||||||
|
self.fake = fake
|
||||||
|
self.freeze_bn = freeze_bn
|
||||||
|
self.momentum = momentum
|
||||||
|
self.quant_delay = quant_delay
|
||||||
|
if isinstance(kernel_size, int):
|
||||||
|
self.kernel_size = (kernel_size, kernel_size)
|
||||||
|
else:
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
if group > 1:
|
||||||
|
validator.check_integer('group', group, 'in_channels', in_channels, 'Conv2dBatchNormQuant')
|
||||||
|
validator.check_integer('group', group, 'in_channels', out_channels, 'Conv2dBatchNormQuant')
|
||||||
|
self.is_depthwise = group > 1
|
||||||
|
|
||||||
|
channel_multiplier = out_channels // in_channels
|
||||||
|
self.conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
pad_mode=pad_mode,
|
||||||
|
pad=padding)
|
||||||
|
|
||||||
|
if weight_init is None:
|
||||||
|
weight_init = initializer('normal', [channel_multiplier, in_channels, *kernel_size])
|
||||||
|
self.weight = Parameter(weight_init, name='weight')
|
||||||
|
if gamma_init is None:
|
||||||
|
gamma_init = initializer('ones', [out_channels])
|
||||||
|
self.gamma = Parameter(gamma_init, name='gamma')
|
||||||
|
if beta_init is None:
|
||||||
|
beta_init = initializer('zeros', [out_channels])
|
||||||
|
self.beta = Parameter(beta_init, name='beta')
|
||||||
|
if mean_init is None:
|
||||||
|
mean_init = initializer('zeros', [out_channels])
|
||||||
|
self.moving_mean = Parameter(
|
||||||
|
mean_init, name='moving_mean', requires_grad=False)
|
||||||
|
if var_init is None:
|
||||||
|
var_init = initializer('ones', [out_channels])
|
||||||
|
self.moving_variance = Parameter(
|
||||||
|
var_init, name='moving_variance', requires_grad=False)
|
||||||
|
|
||||||
|
self.step = Parameter(initializer(
|
||||||
|
'normal', [1], dtype=mstype.int32), name='step', requires_grad=False)
|
||||||
|
|
||||||
|
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
|
||||||
|
max_init=6,
|
||||||
|
ema=False,
|
||||||
|
num_bits=num_bits,
|
||||||
|
quant_delay=quant_delay,
|
||||||
|
per_channel=per_channel,
|
||||||
|
out_channels=out_channels,
|
||||||
|
symmetric=symmetric,
|
||||||
|
narrow_range=narrow_range)
|
||||||
|
self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn)
|
||||||
|
|
||||||
|
self.correct_mul = P.CorrectionMul(self.is_depthwise)
|
||||||
|
if context.get_context('device_target') == "Ascend":
|
||||||
|
self.batchnorm_fold2_train = P.BatchNormFold2_D(freeze_bn=freeze_bn)
|
||||||
|
self.batchnorm_fold2_infer = P.BatchNormFold2_D(freeze_bn=0)
|
||||||
|
elif context.get_context('device_target') == "GPU":
|
||||||
|
self.batchnorm_fold2_train = P.BatchNormFold2(freeze_bn=freeze_bn)
|
||||||
|
self.batchnorm_fold2_infer = P.BatchNormFold2(freeze_bn=0)
|
||||||
|
else:
|
||||||
|
raise ValueError("Not support platform.")
|
||||||
|
self.one = Tensor(1, mstype.int32)
|
||||||
|
self.assignadd = P.AssignAdd()
|
||||||
|
self.is_gpu = context.get_context('device_target') == "GPU"
|
||||||
|
|
||||||
|
def extend_repr(self):
|
||||||
|
s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
|
||||||
|
'pad_mode={}, padding={}, dilation={}, group={}, ' \
|
||||||
|
'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format(
|
||||||
|
self.in_channels, self.out_channels, self.kernel_size, self.stride,
|
||||||
|
self.pad_mode, self.padding, self.dilation, self.group,
|
||||||
|
self.fake, self.freeze_bn, self.momentum, self.quant_delay)
|
||||||
|
return s
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
out_conv = self.conv(x, self.weight)
|
||||||
|
# BN fold1
|
||||||
|
batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold(out_conv,
|
||||||
|
self.moving_mean,
|
||||||
|
self.moving_variance,
|
||||||
|
self.step)
|
||||||
|
# fake weight
|
||||||
|
weight = self.correct_mul(self.weight, self.gamma, running_std)
|
||||||
|
if self.fake:
|
||||||
|
weight = self.fake_quant_weight(weight)
|
||||||
|
out = self.conv(x, weight)
|
||||||
|
# BN fold2
|
||||||
|
if self.is_gpu:
|
||||||
|
if self.training:
|
||||||
|
out = self.batchnorm_fold2_train(out, self.beta, self.gamma,
|
||||||
|
batch_std, batch_mean, running_std, running_mean, self.step)
|
||||||
|
F.control_depend(out, self.assignadd(self.step, self.one))
|
||||||
|
else:
|
||||||
|
out = self.batchnorm_fold2_infer(out, self.beta, self.gamma,
|
||||||
|
batch_std, batch_mean, running_std, running_mean, self.step)
|
||||||
|
else:
|
||||||
|
if self.training:
|
||||||
|
out = self.batchnorm_fold2_train(out, self.beta, self.gamma, batch_std, batch_mean, running_std)
|
||||||
|
F.control_depend(out, self.assignadd(self.step, self.one))
|
||||||
|
else:
|
||||||
|
out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, batch_std, batch_mean, running_std)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
class Conv2dBatchNormQuant(Cell):
|
class Conv2dBatchNormQuant(Cell):
|
||||||
r"""
|
r"""
|
||||||
2D convolution with BatchNormal op folded layer.
|
2D convolution with BatchNormal op folded layer.
|
||||||
|
@ -215,6 +589,7 @@ class Conv2dBatchNormQuant(Cell):
|
||||||
per_channel=False,
|
per_channel=False,
|
||||||
symmetric=False,
|
symmetric=False,
|
||||||
narrow_range=False):
|
narrow_range=False):
|
||||||
|
"""init Conv2dBatchNormQuant layer"""
|
||||||
super(Conv2dBatchNormQuant, self).__init__()
|
super(Conv2dBatchNormQuant, self).__init__()
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
|
@ -231,7 +606,6 @@ class Conv2dBatchNormQuant(Cell):
|
||||||
self.kernel_size = (kernel_size, kernel_size)
|
self.kernel_size = (kernel_size, kernel_size)
|
||||||
else:
|
else:
|
||||||
self.kernel_size = kernel_size
|
self.kernel_size = kernel_size
|
||||||
|
|
||||||
if weight_init is None:
|
if weight_init is None:
|
||||||
weight_init = initializer(
|
weight_init = initializer(
|
||||||
'normal', [out_channels, in_channels // group, *self.kernel_size])
|
'normal', [out_channels, in_channels // group, *self.kernel_size])
|
||||||
|
@ -254,14 +628,6 @@ class Conv2dBatchNormQuant(Cell):
|
||||||
self.step = Parameter(initializer(
|
self.step = Parameter(initializer(
|
||||||
'normal', [1], dtype=mstype.int32), name='step', requires_grad=False)
|
'normal', [1], dtype=mstype.int32), name='step', requires_grad=False)
|
||||||
|
|
||||||
self.conv = P.Conv2D(out_channel=self.out_channels,
|
|
||||||
kernel_size=self.kernel_size,
|
|
||||||
mode=1,
|
|
||||||
pad_mode=self.pad_mode,
|
|
||||||
pad=self.padding,
|
|
||||||
stride=self.stride,
|
|
||||||
dilation=self.dilation,
|
|
||||||
group=self.group)
|
|
||||||
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
|
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
|
||||||
max_init=6,
|
max_init=6,
|
||||||
ema=False,
|
ema=False,
|
||||||
|
@ -271,23 +637,29 @@ class Conv2dBatchNormQuant(Cell):
|
||||||
out_channels=out_channels,
|
out_channels=out_channels,
|
||||||
symmetric=symmetric,
|
symmetric=symmetric,
|
||||||
narrow_range=narrow_range)
|
narrow_range=narrow_range)
|
||||||
self.batchnorm_fold_train = P.BatchNormFold(epsilon=eps,
|
self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn)
|
||||||
momentum=momentum,
|
self.conv = P.Conv2D(out_channel=out_channels,
|
||||||
is_training=True,
|
kernel_size=kernel_size,
|
||||||
freeze_bn=freeze_bn)
|
mode=1,
|
||||||
self.batchnorm_fold_infer = P.BatchNormFold(epsilon=eps,
|
pad_mode=pad_mode,
|
||||||
momentum=momentum,
|
pad=padding,
|
||||||
is_training=False,
|
stride=stride,
|
||||||
freeze_bn=freeze_bn)
|
dilation=1,
|
||||||
|
group=group)
|
||||||
self.correct_mul = P.CorrectionMul()
|
self.correct_mul = P.CorrectionMul()
|
||||||
self.relu = P.ReLU()
|
if context.get_context('device_target') == "Ascend":
|
||||||
self.batchnorm_fold2 = P.BatchNormFold2(freeze_bn=freeze_bn)
|
self.batchnorm_fold2_train = P.BatchNormFold2_D(freeze_bn=freeze_bn)
|
||||||
self.batchnorm_fold2_infer = P.BatchNormFold2(freeze_bn=0)
|
self.batchnorm_fold2_infer = P.BatchNormFold2_D(freeze_bn=0)
|
||||||
|
elif context.get_context('device_target') == "GPU":
|
||||||
|
self.batchnorm_fold2_train = P.BatchNormFold2(freeze_bn=freeze_bn)
|
||||||
|
self.batchnorm_fold2_infer = P.BatchNormFold2(freeze_bn=0)
|
||||||
|
else:
|
||||||
|
raise ValueError("Not support platform.")
|
||||||
self.one = Tensor(1, mstype.int32)
|
self.one = Tensor(1, mstype.int32)
|
||||||
self.assignadd = P.AssignAdd()
|
self.assignadd = P.AssignAdd()
|
||||||
|
|
||||||
def extend_repr(self):
|
def extend_repr(self):
|
||||||
s = 'input_channels={}, output_channels={}, kernel_size={}, stride={}, ' \
|
s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
|
||||||
'pad_mode={}, padding={}, dilation={}, group={}, ' \
|
'pad_mode={}, padding={}, dilation={}, group={}, ' \
|
||||||
'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format(
|
'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format(
|
||||||
self.in_channels, self.out_channels, self.kernel_size, self.stride,
|
self.in_channels, self.out_channels, self.kernel_size, self.stride,
|
||||||
|
@ -296,34 +668,32 @@ class Conv2dBatchNormQuant(Cell):
|
||||||
return s
|
return s
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
if self.training:
|
out_conv = self.conv(x, self.weight)
|
||||||
beta = self.beta
|
# BN fold1
|
||||||
gamma = self.gamma
|
batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold(out_conv,
|
||||||
gmean = self.moving_mean
|
self.moving_mean,
|
||||||
gvar = self.moving_variance
|
self.moving_variance,
|
||||||
step = self.step
|
self.step)
|
||||||
out_conv = self.conv(x, self.weight)
|
# fake weight
|
||||||
batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold_train(
|
weight = self.correct_mul(self.weight, self.gamma, running_std)
|
||||||
out_conv, gmean, gvar, step)
|
if self.fake:
|
||||||
# BN fold1
|
weight = self.fake_quant_weight(weight)
|
||||||
weight = self.correct_mul(self.weight, gamma, running_std)
|
out = self.conv(x, weight)
|
||||||
if self.fake:
|
# BN fold2
|
||||||
weight = self.fake_quant_weight(weight)
|
if self.is_gpu:
|
||||||
out = self.conv(x, weight)
|
if self.training:
|
||||||
# BN fold2
|
out = self.batchnorm_fold2_train(out, self.beta, self.gamma,
|
||||||
out = self.batchnorm_fold2(
|
batch_std, batch_mean, running_std, running_mean, self.step)
|
||||||
out, beta, gamma, batch_std, batch_mean, running_std, running_mean, step)
|
F.control_depend(out, self.assignadd(self.step, self.one))
|
||||||
F.control_depend(out, self.assignadd(self.step, self.one))
|
else:
|
||||||
|
out = self.batchnorm_fold2_infer(out, self.beta, self.gamma,
|
||||||
|
batch_std, batch_mean, running_std, running_mean, self.step)
|
||||||
else:
|
else:
|
||||||
step = self.step
|
if self.training:
|
||||||
batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold_infer(
|
out = self.batchnorm_fold2_train(out, self.beta, self.gamma, batch_std, batch_mean, running_std)
|
||||||
x, self.moving_mean, self.moving_variance, step)
|
F.control_depend(out, self.assignadd(self.step, self.one))
|
||||||
weight = self.correct_mul(self.weight, self.gamma, running_std)
|
else:
|
||||||
if self.fake:
|
out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, batch_std, batch_mean, running_std)
|
||||||
weight = self.fake_quant_weight(weight)
|
|
||||||
out = self.conv(x, weight)
|
|
||||||
out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, batch_std, batch_mean,
|
|
||||||
running_std, running_mean, step)
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
@ -434,7 +804,7 @@ class Conv2dQuant(Cell):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def extend_repr(self):
|
def extend_repr(self):
|
||||||
s = 'input_channels={}, output_channels={}, kernel_size={}, stride={}, ' \
|
s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
|
||||||
'pad_mode={}, padding={}, dilation={}, group={}, ' \
|
'pad_mode={}, padding={}, dilation={}, group={}, ' \
|
||||||
'has_bias={}, quant_delay={}'.format(
|
'has_bias={}, quant_delay={}'.format(
|
||||||
self.in_channels, self.out_channels, self.kernel_size, self.stride,
|
self.in_channels, self.out_channels, self.kernel_size, self.stride,
|
||||||
|
|
|
@ -22,7 +22,7 @@ from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||||
|
|
||||||
@bprop_getters.register(P.FakeQuantWithMinMax)
|
@bprop_getters.register(P.FakeQuantWithMinMax)
|
||||||
def get_bprop_fakequant_with_minmax(self):
|
def get_bprop_fakequant_with_minmax(self):
|
||||||
"""Generate bprop for FakeQuantWithMinMax"""
|
"""Generate bprop for FakeQuantWithMinMax for GPU and Ascend"""
|
||||||
op = P.FakeQuantWithMinMaxGrad(num_bits=self.num_bits, quant_delay=self.quant_delay)
|
op = P.FakeQuantWithMinMaxGrad(num_bits=self.num_bits, quant_delay=self.quant_delay)
|
||||||
|
|
||||||
def bprop(x, x_min, x_max, out, dout):
|
def bprop(x, x_min, x_max, out, dout):
|
||||||
|
@ -34,7 +34,7 @@ def get_bprop_fakequant_with_minmax(self):
|
||||||
|
|
||||||
@bprop_getters.register(P.FakeQuantWithMinMaxPerChannel)
|
@bprop_getters.register(P.FakeQuantWithMinMaxPerChannel)
|
||||||
def get_bprop_fakequant_with_minmax_perchannel(self):
|
def get_bprop_fakequant_with_minmax_perchannel(self):
|
||||||
"""Generate bprop for FakeQuantWithMinMaxPerChannel"""
|
"""Generate bprop for FakeQuantWithMinMaxPerChannel for GPU"""
|
||||||
op = P.FakeQuantWithMinMaxPerChannelGrad(num_bits=self.num_bits, quant_delay=self.quant_delay)
|
op = P.FakeQuantWithMinMaxPerChannelGrad(num_bits=self.num_bits, quant_delay=self.quant_delay)
|
||||||
|
|
||||||
def bprop(x, x_min, x_max, out, dout):
|
def bprop(x, x_min, x_max, out, dout):
|
||||||
|
@ -46,7 +46,7 @@ def get_bprop_fakequant_with_minmax_perchannel(self):
|
||||||
|
|
||||||
@bprop_getters.register(P.BatchNormFold)
|
@bprop_getters.register(P.BatchNormFold)
|
||||||
def get_bprop_batchnorm_fold(self):
|
def get_bprop_batchnorm_fold(self):
|
||||||
"""Generate bprop for BatchNormFold"""
|
"""Generate bprop for BatchNormFold for GPU"""
|
||||||
op = P.BatchNormFoldGrad(self.epsilon, self.is_training, self.freeze_bn)
|
op = P.BatchNormFoldGrad(self.epsilon, self.is_training, self.freeze_bn)
|
||||||
|
|
||||||
def bprop(x, mean, variance, global_step, out, dout):
|
def bprop(x, mean, variance, global_step, out, dout):
|
||||||
|
@ -58,8 +58,8 @@ def get_bprop_batchnorm_fold(self):
|
||||||
|
|
||||||
@bprop_getters.register(P.CorrectionMul)
|
@bprop_getters.register(P.CorrectionMul)
|
||||||
def get_bprop_correction_mul(self):
|
def get_bprop_correction_mul(self):
|
||||||
"""Generate bprop for CorrectionMul"""
|
"""Generate bprop for CorrectionMul for Ascend and GPU"""
|
||||||
grad = P.CorrectionMulGrad()
|
grad = P.CorrectionMulGrad(self.channel_axis)
|
||||||
|
|
||||||
def bprop(x, batch_std, running_std, out, dout):
|
def bprop(x, batch_std, running_std, out, dout):
|
||||||
dx, d_batch_std = grad(dout, x, batch_std, running_std)
|
dx, d_batch_std = grad(dout, x, batch_std, running_std)
|
||||||
|
@ -70,7 +70,7 @@ def get_bprop_correction_mul(self):
|
||||||
|
|
||||||
@bprop_getters.register(P.BatchNormFold2)
|
@bprop_getters.register(P.BatchNormFold2)
|
||||||
def get_bprop_batchnorm_fold2(self):
|
def get_bprop_batchnorm_fold2(self):
|
||||||
"""Generate bprop for CorrectionAdd"""
|
"""Generate bprop for BatchNormFold2 for GPU"""
|
||||||
op_f = P.BatchNormFold2Grad(freeze_bn=self.freeze_bn)
|
op_f = P.BatchNormFold2Grad(freeze_bn=self.freeze_bn)
|
||||||
|
|
||||||
def bprop(x, beta, gamma, batch_std, batch_mean, running_std, running_mean, global_step, out, dout):
|
def bprop(x, beta, gamma, batch_std, batch_mean, running_std, running_mean, global_step, out, dout):
|
||||||
|
@ -80,3 +80,48 @@ def get_bprop_batchnorm_fold2(self):
|
||||||
zeros_like(global_step)
|
zeros_like(global_step)
|
||||||
|
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
|
@bprop_getters.register(P.BatchNormFoldD)
|
||||||
|
def get_bprop_BatchNormFold(self):
|
||||||
|
"""Generate bprop for BatchNormFold for Ascend"""
|
||||||
|
op = P.BatchNormFoldGrad_(self.epsilon, self.is_training, self.freeze_bn)
|
||||||
|
|
||||||
|
def bprop(x, x_sum, x_square_sum, mean, variance, out, dout):
|
||||||
|
dx = op(dout[1], dout[2], x, out[1], out[2])
|
||||||
|
return dx, zeros_like(x_sum), zeros_like(x_square_sum), zeros_like(mean), zeros_like(variance)
|
||||||
|
|
||||||
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
|
@bprop_getters.register(P.BNTrainingReduce)
|
||||||
|
def get_bprop_BNTrainingReduce(self):
|
||||||
|
def bprop(x, out, dout):
|
||||||
|
return (zeros_like(x),)
|
||||||
|
|
||||||
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
|
@bprop_getters.register(P.BatchNormFold2_D)
|
||||||
|
def get_bprop_batchnorm_fold2_(self):
|
||||||
|
"""Generate bprop for BatchNormFold2 for Ascend"""
|
||||||
|
op_reduce = P.BatchNormFold2GradReduce(freeze_bn=self.freeze_bn)
|
||||||
|
op_f = P.BatchNormFold2GradD(freeze_bn=self.freeze_bn)
|
||||||
|
|
||||||
|
def bprop(x, beta, gamma, batch_std, batch_mean, running_std, out, dout):
|
||||||
|
dout_reduce, dout_x_reduce = op_reduce(dout, x)
|
||||||
|
d_batch_std, d_batch_mean, d_gamma, d_x = op_f(dout, dout_reduce, dout_x_reduce, gamma, batch_std,
|
||||||
|
batch_mean, running_std)
|
||||||
|
return d_x, dout_reduce, d_gamma, d_batch_std, d_batch_mean, zeros_like(running_std)
|
||||||
|
|
||||||
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
|
@bprop_getters.register(P.FakeQuantWithMinMaxUpdate)
|
||||||
|
def get_bprop_fakequant_with_minmax_update(self):
|
||||||
|
"""Generate bprop for FakeQuantWithMinMaxUpdate for Ascend"""
|
||||||
|
|
||||||
|
def bprop(x, x_min, x_max, out, dout):
|
||||||
|
return zeros_like(x), zeros_like(x_min), zeros_like(x_max)
|
||||||
|
|
||||||
|
return bprop
|
||||||
|
|
|
@ -0,0 +1,149 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""_BatchNormFold op"""
|
||||||
|
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||||
|
from te import tvm
|
||||||
|
from topi import generic
|
||||||
|
from topi.cce import util
|
||||||
|
|
||||||
|
batch_norm_op_info = TBERegOp("BatchNormFoldD") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.async_flag(False) \
|
||||||
|
.binfile_name("batchnorm_fold.so") \
|
||||||
|
.compute_cost(10) \
|
||||||
|
.kernel_name("batchnorm_fold") \
|
||||||
|
.partial_flag(True) \
|
||||||
|
.attr("momentum", "optional", "float", "all") \
|
||||||
|
.attr("epsilon", "optional", "float", "all") \
|
||||||
|
.attr("is_training", "optional", "bool", "all") \
|
||||||
|
.attr("freeze_bn", "optional", "int", "all") \
|
||||||
|
.attr("data_format", "optional", "str", "all") \
|
||||||
|
.input(0, "x", False, "required", "all") \
|
||||||
|
.input(1, "x_sum", False, "required", "all") \
|
||||||
|
.input(2, "x_square_sum", False, "required", "all") \
|
||||||
|
.input(3, "mean", False, "required", "all") \
|
||||||
|
.input(4, "variance", False, "required", "all") \
|
||||||
|
.output(0, "y", False, "required", "all") \
|
||||||
|
.output(1, "batch_mean", False, "required", "all") \
|
||||||
|
.output(2, "batch_std", False, "required", "all") \
|
||||||
|
.output(3, "running_mean", False, "required", "all") \
|
||||||
|
.output(4, "running_std", False, "required", "all") \
|
||||||
|
.output(5, "mean_updated", False, "required", "all") \
|
||||||
|
.output(6, "variance_updated", False, "required", "all") \
|
||||||
|
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||||
|
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||||
|
DataType.F32_5HD, DataType.F32_5HD) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(batch_norm_op_info)
|
||||||
|
def _batchnorm_fold_tbe():
|
||||||
|
"""_BatchNormFold TBE register"""
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
@util.check_input_type(dict, dict, dict, dict, dict,
|
||||||
|
dict, dict, dict, dict, dict, dict, dict,
|
||||||
|
float, float, bool, int, str, str)
|
||||||
|
def batchnorm_fold(x, x_sum, x_square_sum, mean, variance,
|
||||||
|
y, batch_mean, batch_std, running_mean, running_std, mean_updated, variance_updated,
|
||||||
|
momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0, data_format="NCHW",
|
||||||
|
kernel_name="batchnorm_fold"):
|
||||||
|
"""batchnorm_fold TBE op"""
|
||||||
|
momentum = 1.0 - momentum
|
||||||
|
util.check_kernel_name(kernel_name)
|
||||||
|
data_format = data_format.upper()
|
||||||
|
if data_format != "NCHW":
|
||||||
|
raise RuntimeError("The data_format only support NCHW")
|
||||||
|
|
||||||
|
shape_x = x.get("shape")
|
||||||
|
shape_mean = mean.get("shape")
|
||||||
|
shape_variance = variance.get("shape")
|
||||||
|
dtype_x = x.get("dtype")
|
||||||
|
dtype_mean = mean.get("dtype")
|
||||||
|
dtype_variance = variance.get("dtype")
|
||||||
|
for shape in (shape_x, shape_mean, shape_variance):
|
||||||
|
util.check_shape_rule(shape)
|
||||||
|
util.check_tensor_shape_size(shape)
|
||||||
|
check_tuple = ("float16", "float32")
|
||||||
|
for dtype in (dtype_x, dtype_mean, dtype_variance):
|
||||||
|
util.check_dtype_rule(dtype.lower(), check_tuple)
|
||||||
|
|
||||||
|
format_data = x.get("format").upper()
|
||||||
|
if format_data not in ("NCHW", "NC1HWC0"):
|
||||||
|
raise RuntimeError("Format of input only support 4D and 5HD")
|
||||||
|
|
||||||
|
if format_data == "NC1HWC0":
|
||||||
|
if len(shape_x) != 5:
|
||||||
|
raise RuntimeError("batchnorm_fold only support shape 5D"
|
||||||
|
"when input format is NC1HWC0")
|
||||||
|
shape_mean = (1, shape_x[1], 1, 1, shape_x[4])
|
||||||
|
elif format_data == "NCHW":
|
||||||
|
if len(shape_x) < 2 or len(shape_x) > 4:
|
||||||
|
raise RuntimeError("batchnorm_fold only support shape 2D to 4D")
|
||||||
|
if shape_x[1] != shape_mean[0]:
|
||||||
|
raise RuntimeError("data_format is NCHW, shape_bias must"
|
||||||
|
"be equal to the second axis of shape_x")
|
||||||
|
shape_mean = (1, shape_x[1],)
|
||||||
|
for _ in range(2, len(shape_x)):
|
||||||
|
shape_mean = shape_mean + (1,)
|
||||||
|
|
||||||
|
x_input = tvm.placeholder(shape_x, name="x_input", dtype=dtype_x.lower())
|
||||||
|
x_sum = tvm.placeholder(shape_mean, name="x_sum", dtype=dtype_x.lower())
|
||||||
|
x_square_sum = tvm.placeholder(shape_mean, name="x_square_sum", dtype=dtype_x.lower())
|
||||||
|
mean = tvm.placeholder(shape_mean, name="mean", dtype=dtype_mean.lower())
|
||||||
|
variance = tvm.placeholder(shape_mean, name="variance", dtype=dtype_variance.lower())
|
||||||
|
|
||||||
|
shape_x = te.lang.cce.util.shape_to_list(x_input.shape)
|
||||||
|
num = shape_x[0] * shape_x[2] * shape_x[3]
|
||||||
|
num_rec = 1.0 / num
|
||||||
|
|
||||||
|
# compute the mean of x
|
||||||
|
batch_mean = te.lang.cce.vmuls(x_sum, num_rec)
|
||||||
|
|
||||||
|
# compute the variance of x
|
||||||
|
variance_div = te.lang.cce.vmuls(x_square_sum, num_rec)
|
||||||
|
mean_square = te.lang.cce.vmul(batch_mean, batch_mean)
|
||||||
|
batch_var_biased = te.lang.cce.vsub(variance_div, mean_square)
|
||||||
|
|
||||||
|
if num == 1:
|
||||||
|
batch_var_scaler = 0.0
|
||||||
|
else:
|
||||||
|
batch_var_scaler = float(num) / (num - 1)
|
||||||
|
batch_variance = te.lang.cce.vmuls(batch_var_biased, batch_var_scaler)
|
||||||
|
batch_std = te.lang.cce.vsqrt(te.lang.cce.vadds(batch_variance, epsilon))
|
||||||
|
|
||||||
|
factor = 1.0 - momentum
|
||||||
|
factor_reverse = momentum
|
||||||
|
mean_mul = te.lang.cce.vmuls(batch_mean, factor)
|
||||||
|
mean_mul_rev = te.lang.cce.vmuls(mean, factor_reverse)
|
||||||
|
mean_updated = te.lang.cce.vadd(mean_mul, mean_mul_rev)
|
||||||
|
|
||||||
|
var_mul = te.lang.cce.vmuls(batch_variance, factor)
|
||||||
|
var_mul_rev = te.lang.cce.vmuls(variance, factor_reverse)
|
||||||
|
variance_updated = te.lang.cce.vadd(var_mul, var_mul_rev)
|
||||||
|
|
||||||
|
y = te.lang.cce.vadds(x_input, 0.0)
|
||||||
|
running_mean = te.lang.cce.vadds(mean, 0.0)
|
||||||
|
running_std = te.lang.cce.vsqrt(te.lang.cce.vadds(variance, epsilon))
|
||||||
|
res = [y, batch_mean, batch_std, running_mean, running_std, mean_updated, variance_updated]
|
||||||
|
|
||||||
|
with tvm.target.cce():
|
||||||
|
sch = generic.auto_schedule(res)
|
||||||
|
config = {"name": kernel_name,
|
||||||
|
"tensor_list": [x_input, x_sum, x_square_sum, mean, variance] + res}
|
||||||
|
te.lang.cce.cce_build_code(sch, config)
|
|
@ -0,0 +1,110 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""_BatchNormFold2 op"""
|
||||||
|
import te.lang.cce
|
||||||
|
from te import tvm
|
||||||
|
from te.platform.fusion_manager import fusion_manager
|
||||||
|
from topi import generic
|
||||||
|
from topi.cce import util
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||||
|
|
||||||
|
SHAPE_SIZE_LIMIT = 2147483648
|
||||||
|
|
||||||
|
batchnorm_fold2_op_info = TBERegOp("BatchNormFold2_D") \
|
||||||
|
.fusion_type("ELEMWISE") \
|
||||||
|
.async_flag(False) \
|
||||||
|
.binfile_name("batchnorm_fold2.so") \
|
||||||
|
.compute_cost(10) \
|
||||||
|
.kernel_name("batchnorm_fold2") \
|
||||||
|
.partial_flag(True) \
|
||||||
|
.op_pattern("formatAgnostic") \
|
||||||
|
.input(0, "x", None, "required", None) \
|
||||||
|
.input(1, "beta", None, "required", None) \
|
||||||
|
.input(2, "gamma", None, "required", None) \
|
||||||
|
.input(3, "batch_std", None, "required", None) \
|
||||||
|
.input(4, "batch_mean", None, "required", None) \
|
||||||
|
.input(5, "running_std", None, "required", None) \
|
||||||
|
.output(0, "y", True, "required", "all") \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||||
|
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
|
||||||
|
DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||||
|
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||||
|
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(batchnorm_fold2_op_info)
|
||||||
|
def _batchnorm_fold2_tbe():
|
||||||
|
"""_BatchNormFold2 TBE register"""
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
@fusion_manager.register("batchnorm_fold2")
|
||||||
|
def batchnorm_fold2_compute(x, beta, gamma, batch_std, batch_mean, running_std, kernel_name="batchnorm_fold2"):
|
||||||
|
"""_BatchNormFold2 compute"""
|
||||||
|
shape_x = te.lang.cce.util.shape_to_list(x.shape)
|
||||||
|
factor = te.lang.cce.vdiv(running_std, batch_std)
|
||||||
|
factor_b = te.lang.cce.broadcast(factor, shape_x)
|
||||||
|
res = te.lang.cce.vmul(x, factor_b)
|
||||||
|
bias = te.lang.cce.vdiv(batch_mean, batch_std)
|
||||||
|
bias = te.lang.cce.vmul(bias, gamma)
|
||||||
|
bias = te.lang.cce.vsub(beta, bias)
|
||||||
|
bias_b = te.lang.cce.broadcast(bias, shape_x)
|
||||||
|
res = te.lang.cce.vadd(res, bias_b)
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
@util.check_input_type(dict, dict, dict, dict, dict, dict, dict, str)
|
||||||
|
def batchnorm_fold2(x, beta, gamma, batch_std, batch_mean, running_std, y, kernel_name="batchnorm_fold2"):
|
||||||
|
"""_BatchNormFold2 op"""
|
||||||
|
shape = x.get("shape")
|
||||||
|
util.check_kernel_name(kernel_name)
|
||||||
|
util.check_shape_rule(shape)
|
||||||
|
util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
|
||||||
|
check_list = ["float16", "float32"]
|
||||||
|
inp_dtype = x.get("dtype").lower()
|
||||||
|
if not inp_dtype in check_list:
|
||||||
|
raise RuntimeError("Dtype of input only support float16, float32")
|
||||||
|
data_format = x.get("format")
|
||||||
|
ori_format = x.get("ori_format")
|
||||||
|
if data_format.upper() not in ("NC1HWC0", "NCHW"):
|
||||||
|
raise RuntimeError("Un supported data format {}".format(data_format))
|
||||||
|
if data_format.upper() == "NCHW" and ori_format != "NCHW":
|
||||||
|
raise RuntimeError("data_format(NCHW) must same as ori_format")
|
||||||
|
shape_c = gamma.get("shape")
|
||||||
|
if gamma.get("format").upper() == "NCHW":
|
||||||
|
shape_c = 1, gamma.get("shape")[0], 1, 1
|
||||||
|
x_t = tvm.placeholder(shape, name="x", dtype=inp_dtype)
|
||||||
|
beta_t = tvm.placeholder(shape_c, name="beta", dtype=inp_dtype)
|
||||||
|
gamma_t = tvm.placeholder(shape_c, name="gamma", dtype=inp_dtype)
|
||||||
|
batch_std_t = tvm.placeholder(shape_c, name="batch_std", dtype=inp_dtype)
|
||||||
|
batch_mean_t = tvm.placeholder(shape_c, name="batch_mean", dtype=inp_dtype)
|
||||||
|
running_std_t = tvm.placeholder(shape_c, name="running_std", dtype=inp_dtype)
|
||||||
|
|
||||||
|
res = batchnorm_fold2_compute(x_t, beta_t, gamma_t, batch_std_t, batch_mean_t,
|
||||||
|
running_std_t, kernel_name)
|
||||||
|
|
||||||
|
with tvm.target.cce():
|
||||||
|
sch = generic.auto_schedule(res)
|
||||||
|
|
||||||
|
config = {"print_ir": False,
|
||||||
|
"name": kernel_name,
|
||||||
|
"tensor_list": [x_t, beta_t, gamma_t, batch_std_t, batch_mean_t, running_std_t, res]}
|
||||||
|
|
||||||
|
te.lang.cce.cce_build_code(sch, config)
|
|
@ -0,0 +1,126 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""_BatchNormFold2Grad op"""
|
||||||
|
import te.lang.cce
|
||||||
|
from te import tvm
|
||||||
|
from te.platform.fusion_manager import fusion_manager
|
||||||
|
from topi import generic
|
||||||
|
from topi.cce import util
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||||
|
|
||||||
|
SHAPE_SIZE_LIMIT = 2147483648
|
||||||
|
|
||||||
|
batchnorm_fold2_grad_op_info = TBERegOp("BatchNormFold2GradD") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.async_flag(False) \
|
||||||
|
.binfile_name("batchnorm_fold2_grad.so") \
|
||||||
|
.compute_cost(10) \
|
||||||
|
.kernel_name("batchnorm_fold2_grad") \
|
||||||
|
.partial_flag(True) \
|
||||||
|
.op_pattern("formatAgnostic") \
|
||||||
|
.input(0, "dout", None, "required", None) \
|
||||||
|
.input(1, "dout_reduce", None, "required", None) \
|
||||||
|
.input(2, "dout_x_reduce", None, "required", None) \
|
||||||
|
.input(3, "gamma", None, "required", None) \
|
||||||
|
.input(4, "batch_std", None, "required", None) \
|
||||||
|
.input(5, "batch_mean", None, "required", None) \
|
||||||
|
.input(6, "running_std", None, "required", None) \
|
||||||
|
.output(0, "d_batch_std", True, "required", "all") \
|
||||||
|
.output(1, "d_batch_mean", True, "required", "all") \
|
||||||
|
.output(2, "d_gamma", True, "required", "all") \
|
||||||
|
.output(3, "dx", True, "required", "all") \
|
||||||
|
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||||
|
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||||
|
DataType.F32_5HD) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(batchnorm_fold2_grad_op_info)
|
||||||
|
def _batchnorm_fold2_grad_tbe():
|
||||||
|
"""_BatchNormFold2Grad TBE register"""
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
@fusion_manager.register("batchnorm_fold2_grad")
|
||||||
|
def batchnorm_fold2_grad_compute(dout, dout_reduce, dout_x_reduce, gamma, batch_std, batch_mean, running_std,
|
||||||
|
kernel_name="batchnorm_fold2_grad"):
|
||||||
|
"""_BatchNormFold2Grad"""
|
||||||
|
shape_x = te.lang.cce.util.shape_to_list(dout.shape)
|
||||||
|
|
||||||
|
d_batch_std_1 = te.lang.cce.vmul(dout_reduce, batch_mean)
|
||||||
|
d_batch_std_1 = te.lang.cce.vmul(d_batch_std_1, gamma)
|
||||||
|
d_batch_std_2 = te.lang.cce.vmul(dout_x_reduce, running_std)
|
||||||
|
d_batch_std = te.lang.cce.vsub(d_batch_std_1, d_batch_std_2)
|
||||||
|
d_batch_std = te.lang.cce.vdiv(d_batch_std, batch_std)
|
||||||
|
d_batch_std = te.lang.cce.vdiv(d_batch_std, batch_std)
|
||||||
|
|
||||||
|
d_batch_mean = te.lang.cce.vmul(dout_reduce, gamma)
|
||||||
|
d_batch_mean = te.lang.cce.vdiv(d_batch_mean, batch_std)
|
||||||
|
d_batch_mean = te.lang.cce.vmuls(d_batch_mean, -1.)
|
||||||
|
|
||||||
|
d_gamma = te.lang.cce.vmul(dout_reduce, batch_mean)
|
||||||
|
d_gamma = te.lang.cce.vdiv(d_gamma, batch_std)
|
||||||
|
d_gamma = te.lang.cce.vmuls(d_gamma, -1.)
|
||||||
|
|
||||||
|
dx = te.lang.cce.vdiv(running_std, batch_std)
|
||||||
|
dx = te.lang.cce.broadcast(dx, shape_x)
|
||||||
|
dx = te.lang.cce.vmul(dx, dout)
|
||||||
|
return [d_batch_std, d_batch_mean, d_gamma, dx]
|
||||||
|
|
||||||
|
|
||||||
|
@util.check_input_type(dict, dict, dict, dict, dict, dict, dict, dict, dict, dict, dict, str)
|
||||||
|
def batchnorm_fold2_grad(dout, dout_reduce, dout_x_reduce, gamma, batch_std, batch_mean, running_std, d_batch_std,
|
||||||
|
d_batch_mean, d_gamma, dx, kernel_name="batchnorm_fold2_grad"):
|
||||||
|
"""_BatchNormFold2Grad op """
|
||||||
|
shape = dout.get("shape")
|
||||||
|
util.check_kernel_name(kernel_name)
|
||||||
|
util.check_shape_rule(shape)
|
||||||
|
util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
|
||||||
|
check_list = ["float16", "float32"]
|
||||||
|
inp_dtype = dout.get("dtype").lower()
|
||||||
|
if not inp_dtype in check_list:
|
||||||
|
raise RuntimeError("Dtype of input only support float16, float32")
|
||||||
|
data_format = dout.get("format")
|
||||||
|
ori_format = dout.get("ori_format")
|
||||||
|
if data_format.upper() not in ("NC1HWC0", "NCHW"):
|
||||||
|
raise RuntimeError("Un supported data format {}".format(data_format))
|
||||||
|
if data_format.upper() == "NCHW" and ori_format != "NCHW":
|
||||||
|
raise RuntimeError("data_format(NCHW) must same as ori_format")
|
||||||
|
shape_c = gamma.get("shape")
|
||||||
|
if gamma.get("format").upper() == "NCHW":
|
||||||
|
shape_c = 1, gamma.get("shape")[0], 1, 1
|
||||||
|
|
||||||
|
dout_t = tvm.placeholder(shape, name="dout", dtype=inp_dtype)
|
||||||
|
dout_reduce_t = tvm.placeholder(shape_c, name="dout_reduce", dtype=inp_dtype)
|
||||||
|
dout_x_reduce_t = tvm.placeholder(shape_c, name="dout_x_reduce", dtype=inp_dtype)
|
||||||
|
gamma_t = tvm.placeholder(shape_c, name="gamma", dtype=inp_dtype)
|
||||||
|
batch_std_t = tvm.placeholder(shape_c, name="batch_std", dtype=inp_dtype)
|
||||||
|
batch_mean_t = tvm.placeholder(shape_c, name="batch_mean", dtype=inp_dtype)
|
||||||
|
running_std_t = tvm.placeholder(shape_c, name="running_std", dtype=inp_dtype)
|
||||||
|
|
||||||
|
res_list = batchnorm_fold2_grad_compute(dout_t, dout_reduce_t, dout_x_reduce_t, gamma_t, batch_std_t, batch_mean_t,
|
||||||
|
running_std_t, kernel_name)
|
||||||
|
|
||||||
|
with tvm.target.cce():
|
||||||
|
sch = generic.auto_schedule(res_list)
|
||||||
|
|
||||||
|
tensor_list = [dout_t, dout_reduce_t, dout_x_reduce_t, gamma_t, batch_std_t, batch_mean_t, running_std_t] + list(
|
||||||
|
res_list)
|
||||||
|
config = {"print_ir": False,
|
||||||
|
"name": kernel_name,
|
||||||
|
"tensor_list": tensor_list}
|
||||||
|
|
||||||
|
te.lang.cce.cce_build_code(sch, config)
|
|
@ -0,0 +1,107 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""_BatchNormFold2GradReduce op"""
|
||||||
|
import te.lang.cce
|
||||||
|
from te import tvm
|
||||||
|
from te.platform.fusion_manager import fusion_manager
|
||||||
|
from te.platform.cce_build import build_config
|
||||||
|
from topi import generic
|
||||||
|
from topi.cce import util
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||||
|
|
||||||
|
SHAPE_SIZE_LIMIT = 2147483648
|
||||||
|
|
||||||
|
batchnorm_fold2_grad_reduce_op_info = TBERegOp("BatchNormFold2GradReduce") \
|
||||||
|
.fusion_type("ELEMWISE") \
|
||||||
|
.async_flag(False) \
|
||||||
|
.binfile_name("batchnorm_fold2_grad_reduce.so") \
|
||||||
|
.compute_cost(10) \
|
||||||
|
.kernel_name("batchnorm_fold2_grad_reduce") \
|
||||||
|
.partial_flag(True) \
|
||||||
|
.op_pattern("formatAgnostic") \
|
||||||
|
.input(0, "dout", None, "required", None) \
|
||||||
|
.input(1, "x", None, "required", None) \
|
||||||
|
.output(0, "dout_reduce", True, "required", "all") \
|
||||||
|
.output(1, "dout_x_reduce", True, "required", "all") \
|
||||||
|
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(batchnorm_fold2_grad_reduce_op_info)
|
||||||
|
def _batchnorm_fold2_grad_reduce_tbe():
|
||||||
|
"""_BatchNormFold2GradReduce TBE register"""
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
@fusion_manager.register("batchnorm_fold2_grad_reduce")
|
||||||
|
def batchnorm_fold2_grad_reduce_compute(dout, x, dout_args, kernel_name="batchnorm_fold2_grad_reduce"):
|
||||||
|
"""_BatchNormFold2GradReduce compute"""
|
||||||
|
dtype = dout_args.get("dtype")
|
||||||
|
dout_format = dout_args.get("format")
|
||||||
|
ori_format = dout_args.get("ori_format")
|
||||||
|
shape = dout_args.get("shape")
|
||||||
|
|
||||||
|
if dtype == "float16":
|
||||||
|
dout = te.lang.cce.cast_to(dout, "float32")
|
||||||
|
x = te.lang.cce.cast_to(x, "float32")
|
||||||
|
|
||||||
|
dout_x = te.lang.cce.vmul(dout, x)
|
||||||
|
if dout_format == "NC1HWC0":
|
||||||
|
axis = [0, 2, 3]
|
||||||
|
dout_reduce, dout_x_reduce = te.lang.cce.tuple_sum([dout, dout_x], axis, True)
|
||||||
|
else:
|
||||||
|
axis = list(range(len(shape)))
|
||||||
|
if ori_format == "NCHW":
|
||||||
|
axis.pop(1)
|
||||||
|
for _, i in enumerate(range(len(shape))):
|
||||||
|
if shape[i] == 1 and i in axis:
|
||||||
|
axis.remove(i)
|
||||||
|
dout_reduce = te.lang.cce.sum(dout, axis, False)
|
||||||
|
dout_x_reduce = te.lang.cce.sum(dout_x, axis, False)
|
||||||
|
return [dout_reduce, dout_x_reduce]
|
||||||
|
|
||||||
|
|
||||||
|
@util.check_input_type(dict, dict, dict, dict, str)
|
||||||
|
def batchnorm_fold2_grad_reduce(dout, x, dout_reduce, dout_x_reduce, kernel_name="batchnorm_fold2_grad_reduce"):
|
||||||
|
"""_BatchNormFold2GradReduce op"""
|
||||||
|
shape = x.get("shape")
|
||||||
|
x_format = x.get("format")
|
||||||
|
util.check_kernel_name(kernel_name)
|
||||||
|
util.check_shape_rule(shape)
|
||||||
|
util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
|
||||||
|
check_list = ["float16", "float32"]
|
||||||
|
inp_dtype = x.get("dtype").lower()
|
||||||
|
if not inp_dtype in check_list:
|
||||||
|
raise RuntimeError("Dtype of input only support float16, float32")
|
||||||
|
dout_t = tvm.placeholder(shape, name="dout", dtype=inp_dtype)
|
||||||
|
x_t = tvm.placeholder(shape, name="x", dtype=inp_dtype)
|
||||||
|
|
||||||
|
res_list = batchnorm_fold2_grad_reduce_compute(dout_t, x_t, dout, kernel_name)
|
||||||
|
|
||||||
|
if x_format == "NC1HWC0":
|
||||||
|
with tvm.target.cce():
|
||||||
|
sch = generic.auto_schedule(res_list)
|
||||||
|
tensor_list = [dout_t, x_t] + list(res_list)
|
||||||
|
config = {"print_ir": False,
|
||||||
|
"name": kernel_name,
|
||||||
|
"tensor_list": tensor_list}
|
||||||
|
|
||||||
|
te.lang.cce.cce_build_code(sch, config)
|
||||||
|
return
|
||||||
|
from impl.bn_training_reduce import bn_training_reduce_schedule_nd
|
||||||
|
sch, tensor_list = bn_training_reduce_schedule_nd(res_list)
|
||||||
|
with build_config:
|
||||||
|
tvm.build(sch, tensor_list, "cce", name=kernel_name)
|
|
@ -0,0 +1,124 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""_BatchNormFoldGrad op"""
|
||||||
|
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||||
|
import te.lang.cce
|
||||||
|
from te import tvm
|
||||||
|
from topi import generic
|
||||||
|
from topi.cce import util
|
||||||
|
|
||||||
|
batch_norm_op_info = TBERegOp("BatchNormFoldGradD") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.async_flag(False) \
|
||||||
|
.binfile_name("batchnorm_fold_grad.so") \
|
||||||
|
.compute_cost(10) \
|
||||||
|
.kernel_name("batchnorm_fold_grad") \
|
||||||
|
.partial_flag(True) \
|
||||||
|
.attr("epsilon", "optional", "float", "all") \
|
||||||
|
.attr("is_training", "optional", "bool", "all") \
|
||||||
|
.attr("freeze_bn", "optional", "int", "all") \
|
||||||
|
.input(0, "d_batch_mean", False, "required", "all") \
|
||||||
|
.input(1, "d_batch_std", False, "required", "all") \
|
||||||
|
.input(2, "x", False, "required", "all") \
|
||||||
|
.input(3, "batch_mean", False, "required", "all") \
|
||||||
|
.input(4, "batch_std", False, "required", "all") \
|
||||||
|
.output(0, "dx", False, "required", "all") \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F16_Default, DataType.F32_Default,
|
||||||
|
DataType.F32_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD, DataType.F32_5HD,
|
||||||
|
DataType.F32_5HD, DataType.F16_5HD) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||||
|
DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||||
|
DataType.F32_5HD, DataType.F32_5HD) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(batch_norm_op_info)
|
||||||
|
def _batchnorm_fold_grad_tbe():
|
||||||
|
"""_BatchNormFoldGrad TBE register"""
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def _batchnorm_fold_grad_compute(d_batch_mean, d_batch_std, data_x, batch_mean, batch_std):
|
||||||
|
"""_batchnorm_fold_grad_compute """
|
||||||
|
shape_x = te.lang.cce.util.shape_to_list(data_x.shape)
|
||||||
|
normal_size = shape_x[0] * shape_x[2] * shape_x[3]
|
||||||
|
|
||||||
|
d_batch_mean_broad = te.lang.cce.broadcast(d_batch_mean, shape_x)
|
||||||
|
d_batch_std_broad = te.lang.cce.broadcast(d_batch_std, shape_x)
|
||||||
|
batch_mean_broad = te.lang.cce.broadcast(batch_mean, shape_x)
|
||||||
|
batch_std_broad = te.lang.cce.broadcast(batch_std, shape_x)
|
||||||
|
|
||||||
|
dx = te.lang.cce.vsub(data_x, batch_mean_broad)
|
||||||
|
dx = te.lang.cce.vmul(dx, d_batch_std_broad)
|
||||||
|
dx = te.lang.cce.vdiv(dx, batch_std_broad)
|
||||||
|
dx = te.lang.cce.vadd(dx, d_batch_mean_broad)
|
||||||
|
dx = te.lang.cce.vmuls(dx, tvm.const(1. / normal_size, dtype=dx.dtype))
|
||||||
|
return [dx]
|
||||||
|
|
||||||
|
|
||||||
|
@util.check_input_type(dict, dict, dict, dict, dict, dict,
|
||||||
|
float, bool, int, str)
|
||||||
|
def batchnorm_fold_grad(d_batch_mean, d_batch_std, x, batch_mean, batch_std, dx,
|
||||||
|
epsilon=1e-5, is_training=True, freeze_bn=0, kernel_name="batchnorm_fold_grad"):
|
||||||
|
"""batchnorm_fold_grad op """
|
||||||
|
util.check_kernel_name(kernel_name)
|
||||||
|
for iv in (d_batch_mean, d_batch_std, x, batch_mean, batch_std):
|
||||||
|
util.check_shape_rule(iv.get("shape"))
|
||||||
|
util.check_tensor_shape_size(iv.get("shape"))
|
||||||
|
check_tuple = ("float16", "float32")
|
||||||
|
for iv in (d_batch_mean, d_batch_std, x, batch_mean, batch_std):
|
||||||
|
util.check_dtype_rule(iv.get("dtype").lower(), check_tuple)
|
||||||
|
|
||||||
|
shape_x = x.get("shape")
|
||||||
|
dtype_x = x.get("dtype")
|
||||||
|
format_data = x.get("format").upper()
|
||||||
|
if format_data not in ("NCHW", "NC1HWC0"):
|
||||||
|
raise RuntimeError("Format of input only support 4D and 5HD")
|
||||||
|
|
||||||
|
shape_mean = d_batch_mean.get("shape")
|
||||||
|
dtype_mean = d_batch_mean.get("dtype").lower()
|
||||||
|
if format_data == "NC1HWC0":
|
||||||
|
if len(shape_x) != 5:
|
||||||
|
raise RuntimeError("batchnorm_fold only support shape 5D"
|
||||||
|
"when input format is NC1HWC0")
|
||||||
|
shape_mean = (1, shape_x[1], 1, 1, shape_x[4])
|
||||||
|
elif format_data == "NCHW":
|
||||||
|
if len(shape_x) < 2 or len(shape_x) > 4:
|
||||||
|
raise RuntimeError("batchnorm_fold only support shape 2D to 4D")
|
||||||
|
if shape_x[1] != shape_mean[0]:
|
||||||
|
raise RuntimeError("data_format is NCHW, shape_bias must"
|
||||||
|
"be equal to the second axis of shape_x")
|
||||||
|
shape_mean = (1, shape_x[1],)
|
||||||
|
for _ in range(2, len(shape_x)):
|
||||||
|
shape_mean = shape_mean + (1,)
|
||||||
|
|
||||||
|
d_batch_mean = tvm.placeholder(shape_mean, name="d_batch_mean", dtype=dtype_mean)
|
||||||
|
d_batch_std = tvm.placeholder(shape_mean, name="d_batch_std", dtype=dtype_mean)
|
||||||
|
data_x = tvm.placeholder(shape_x, name="data_x", dtype=dtype_x.lower())
|
||||||
|
batch_mean = tvm.placeholder(shape_mean, name="batch_mean", dtype=dtype_mean)
|
||||||
|
batch_std = tvm.placeholder(shape_mean, name="batch_std", dtype=dtype_mean)
|
||||||
|
|
||||||
|
res = _batchnorm_fold_grad_compute(d_batch_mean, d_batch_std, data_x, batch_mean, batch_std)
|
||||||
|
with tvm.target.cce():
|
||||||
|
sch = generic.auto_schedule(res)
|
||||||
|
|
||||||
|
tensor_list = [d_batch_mean, d_batch_std, data_x, batch_mean, batch_std] + res
|
||||||
|
config = {"name": kernel_name,
|
||||||
|
"tensor_list": tensor_list}
|
||||||
|
te.lang.cce.cce_build_code(sch, config)
|
|
@ -0,0 +1,92 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""CorrectionMul op"""
|
||||||
|
import te.lang.cce
|
||||||
|
from te import tvm
|
||||||
|
from te.platform.fusion_manager import fusion_manager
|
||||||
|
from topi import generic
|
||||||
|
from topi.cce import util
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||||
|
|
||||||
|
SHAPE_SIZE_LIMIT = 2147483648
|
||||||
|
|
||||||
|
correction_mul_op_info = TBERegOp("CorrectionMul") \
|
||||||
|
.fusion_type("ELEMWISE") \
|
||||||
|
.async_flag(False) \
|
||||||
|
.binfile_name("correction_mul.so") \
|
||||||
|
.compute_cost(10) \
|
||||||
|
.kernel_name("correction_mul") \
|
||||||
|
.partial_flag(True) \
|
||||||
|
.op_pattern("formatAgnostic") \
|
||||||
|
.attr("channel_axis", "optional", "int", "all") \
|
||||||
|
.input(0, "x", None, "required", None) \
|
||||||
|
.input(1, "batch_std", None, "required", None) \
|
||||||
|
.input(2, "running_std", None, "required", None) \
|
||||||
|
.output(0, "y", True, "required", "all") \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(correction_mul_op_info)
|
||||||
|
def _correction_mul_tbe():
|
||||||
|
"""CorrectionMul TBE register"""
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
@fusion_manager.register("correction_mul")
|
||||||
|
def correction_mul_compute(x, batch_std, running_std, kernel_name="correction_mul"):
|
||||||
|
"""CorrectionMul compute"""
|
||||||
|
shape_x = te.lang.cce.util.shape_to_list(x.shape)
|
||||||
|
factor = te.lang.cce.vdiv(batch_std, running_std)
|
||||||
|
factor_b = te.lang.cce.broadcast(factor, shape_x)
|
||||||
|
res = te.lang.cce.vmul(x, factor_b)
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
@util.check_input_type(dict, dict, dict, dict, int, str)
|
||||||
|
def correction_mul(x, batch_std, running_std, y, channel, kernel_name="correction_mul"):
|
||||||
|
"""CorrectionMul op"""
|
||||||
|
shape = x.get("shape")
|
||||||
|
data_format = x.get("format")
|
||||||
|
util.check_kernel_name(kernel_name)
|
||||||
|
util.check_shape_rule(shape)
|
||||||
|
util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
|
||||||
|
check_list = ["float16", "float32"]
|
||||||
|
inp_dtype = x.get("dtype").lower()
|
||||||
|
if not inp_dtype in check_list:
|
||||||
|
raise RuntimeError("Dtype of input only support float16, float32")
|
||||||
|
|
||||||
|
# shape = util.shape_refine(shape)
|
||||||
|
x_t = tvm.placeholder(shape, name="x", dtype=inp_dtype)
|
||||||
|
shape_c = [1] * len(shape)
|
||||||
|
shape_c[channel] = batch_std.get("ori_shape")[0]
|
||||||
|
if data_format == "NC1HWC0" and channel == 1:
|
||||||
|
shape_c = batch_std.get("shape")
|
||||||
|
batch_std_t = tvm.placeholder(shape_c, name="batch_std", dtype=inp_dtype)
|
||||||
|
running_std_t = tvm.placeholder(shape_c, name="running_std", dtype=inp_dtype)
|
||||||
|
res = correction_mul_compute(x_t, batch_std_t, running_std_t, kernel_name)
|
||||||
|
|
||||||
|
with tvm.target.cce():
|
||||||
|
sch = generic.auto_schedule(res)
|
||||||
|
|
||||||
|
config = {"print_ir": False,
|
||||||
|
"name": kernel_name,
|
||||||
|
"tensor_list": [x_t, batch_std_t, running_std_t, res]}
|
||||||
|
|
||||||
|
te.lang.cce.cce_build_code(sch, config)
|
|
@ -0,0 +1,134 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""CorrectionMul op"""
|
||||||
|
import te.lang.cce
|
||||||
|
from te import tvm
|
||||||
|
from te.platform.fusion_manager import fusion_manager
|
||||||
|
from topi import generic
|
||||||
|
from topi.cce import util
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||||
|
|
||||||
|
SHAPE_SIZE_LIMIT = 2147483648
|
||||||
|
|
||||||
|
correction_mul_grad_op_info = TBERegOp("CorrectionMulGrad") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.async_flag(False) \
|
||||||
|
.binfile_name("correction_mul_grad.so") \
|
||||||
|
.compute_cost(10) \
|
||||||
|
.kernel_name("correction_mul_grad") \
|
||||||
|
.partial_flag(True) \
|
||||||
|
.op_pattern("formatAgnostic") \
|
||||||
|
.attr("channel_axis", "optional", "int", "all") \
|
||||||
|
.input(0, "dout", None, "required", None) \
|
||||||
|
.input(1, "x", None, "required", None) \
|
||||||
|
.input(2, "batch_std", None, "required", None) \
|
||||||
|
.input(3, "running_std", None, "required", None) \
|
||||||
|
.output(0, "dx", True, "required", "all") \
|
||||||
|
.output(1, "d_batch_std", True, "required", "all") \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||||
|
DataType.F16_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
|
||||||
|
DataType.F16_5HD, DataType.F16_5HD) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||||
|
DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||||
|
DataType.F32_5HD, DataType.F32_5HD) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(correction_mul_grad_op_info)
|
||||||
|
def _correction_mul_grad_tbe():
|
||||||
|
"""CorrectionMulGrad TBE register"""
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
@fusion_manager.register("correction_mul_grad")
|
||||||
|
def correction_mul_grad_compute(dout, x, batch_std, running_std, channel, data_format, kernel_name="correction_mul"):
|
||||||
|
"""CorrectionMulGrad compute"""
|
||||||
|
shape_x = te.lang.cce.util.shape_to_list(x.shape)
|
||||||
|
factor = te.lang.cce.vdiv(batch_std, running_std)
|
||||||
|
factor_b = te.lang.cce.broadcast(factor, shape_x)
|
||||||
|
dx = te.lang.cce.vmul(dout, factor_b)
|
||||||
|
mul_data = te.lang.cce.vmul(dout, x)
|
||||||
|
if channel == 0:
|
||||||
|
if data_format == "NCHW":
|
||||||
|
axis = [1, 2, 3]
|
||||||
|
else:
|
||||||
|
axis = [1, 2, 3, 4]
|
||||||
|
else:
|
||||||
|
axis = [2, 3]
|
||||||
|
red_data = te.lang.cce.sum(mul_data, axis, keepdims=True)
|
||||||
|
d_batch_std = te.lang.cce.vdiv(red_data, running_std)
|
||||||
|
return [dx, d_batch_std]
|
||||||
|
|
||||||
|
|
||||||
|
@util.check_input_type(dict, dict, dict, dict, dict, dict, int, str)
|
||||||
|
def correction_mul_grad(dout, x, batch_std, running_std, dx, d_batch_std, channel, kernel_name="correction_mul_grad"):
|
||||||
|
"""CorrectionMulGrad op"""
|
||||||
|
shape_dout = dout.get("shape")
|
||||||
|
shape_x = dout.get("shape")
|
||||||
|
|
||||||
|
dtype_dout = dout.get("dtype")
|
||||||
|
dtype_x = x.get("dtype")
|
||||||
|
dtype_batch_std = batch_std.get("dtype")
|
||||||
|
dtype_running_std = running_std.get("dtype")
|
||||||
|
|
||||||
|
inp_dtype_dout = dtype_dout.lower()
|
||||||
|
inp_dtype_x = dtype_x.lower()
|
||||||
|
inp_dtype_batch_std = dtype_batch_std.lower()
|
||||||
|
inp_dtype_running_std = dtype_running_std.lower()
|
||||||
|
|
||||||
|
util.check_dtype_rule(inp_dtype_dout, ("float16", "float32"))
|
||||||
|
util.check_dtype_rule(inp_dtype_x, ("float16", "float32"))
|
||||||
|
util.check_dtype_rule(inp_dtype_batch_std, ("float32",))
|
||||||
|
util.check_dtype_rule(inp_dtype_running_std, ("float32",))
|
||||||
|
util.compare_tensor_dict_key(dout, x, "dtype")
|
||||||
|
util.compare_tensor_dict_key(dout, x, "shape")
|
||||||
|
util.compare_tensor_dict_key(dx, x, "shape")
|
||||||
|
util.compare_tensor_dict_key(batch_std, running_std, "shape")
|
||||||
|
util.compare_tensor_dict_key(batch_std, d_batch_std, "shape")
|
||||||
|
|
||||||
|
util.check_kernel_name(kernel_name)
|
||||||
|
util.check_shape_rule(shape_x)
|
||||||
|
util.check_shape_size(shape_x, SHAPE_SIZE_LIMIT)
|
||||||
|
|
||||||
|
data_format = dout.get("format")
|
||||||
|
ori_format = dout.get("format")
|
||||||
|
if data_format.upper() not in ("NC1HWC0", "NCHW"):
|
||||||
|
raise RuntimeError("Un supported data format {}".format(data_format))
|
||||||
|
if data_format.upper() == "NCHW" and ori_format != "NCHW":
|
||||||
|
raise RuntimeError("data_format(NCHW) must same as ori_format")
|
||||||
|
|
||||||
|
shape_c = [1] * len(shape_x)
|
||||||
|
shape_c[channel] = batch_std.get("ori_shape")[0]
|
||||||
|
if data_format == "NC1HWC0" and channel == 1:
|
||||||
|
shape_c = batch_std.get("shape")
|
||||||
|
|
||||||
|
dout_t = tvm.placeholder(shape_dout, name="dout", dtype=inp_dtype_dout)
|
||||||
|
x_t = tvm.placeholder(shape_x, name="x", dtype=inp_dtype_x)
|
||||||
|
batch_std_t = tvm.placeholder(shape_c, name="batch_std", dtype=inp_dtype_batch_std)
|
||||||
|
running_std_t = tvm.placeholder(shape_c, name="running_std", dtype=inp_dtype_running_std)
|
||||||
|
res_list = correction_mul_grad_compute(dout_t, x_t, batch_std_t, running_std_t, channel, data_format, kernel_name)
|
||||||
|
|
||||||
|
with tvm.target.cce():
|
||||||
|
sch = generic.auto_schedule(res_list)
|
||||||
|
|
||||||
|
tensor_list = [dout_t, x_t, batch_std_t, running_std_t] + list(res_list)
|
||||||
|
config = {"print_ir": False,
|
||||||
|
"name": kernel_name,
|
||||||
|
"tensor_list": tensor_list}
|
||||||
|
|
||||||
|
te.lang.cce.cce_build_code(sch, config)
|
|
@ -0,0 +1,146 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""FakeQuantWithMinMax op"""
|
||||||
|
|
||||||
|
from functools import reduce as functools_reduce
|
||||||
|
import te.lang.cce
|
||||||
|
from te import tvm
|
||||||
|
from te.platform.fusion_manager import fusion_manager
|
||||||
|
from topi import generic
|
||||||
|
from topi.cce import util
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||||
|
|
||||||
|
fake_quant_op_info = TBERegOp("FakeQuantWithMinMax") \
|
||||||
|
.fusion_type("ELEMWISE") \
|
||||||
|
.async_flag(False) \
|
||||||
|
.binfile_name("fake_quant_with_min_max_vars_ema.so") \
|
||||||
|
.compute_cost(10) \
|
||||||
|
.kernel_name("fake_quant_with_min_max_vars_ema") \
|
||||||
|
.partial_flag(True) \
|
||||||
|
.attr("ema", "optional", "bool", "all") \
|
||||||
|
.attr("ema_decay", "optional", "float", "all") \
|
||||||
|
.attr("symmetric", "optional", "bool", "all") \
|
||||||
|
.attr("narrow_range", "optional", "bool", "all") \
|
||||||
|
.attr("training", "optional", "bool", "all") \
|
||||||
|
.attr("num_bits", "optional", "int", "all") \
|
||||||
|
.attr("quant_delay", "optional", "int", "all") \
|
||||||
|
.input(0, "x", None, "required", None) \
|
||||||
|
.input(1, "min", None, "required", None) \
|
||||||
|
.input(2, "max", None, "required", None) \
|
||||||
|
.output(0, "y", True, "required", "all") \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(fake_quant_op_info)
|
||||||
|
def _fake_quant_tbe():
|
||||||
|
"""FakeQuantWithMinMax TBE register"""
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
@fusion_manager.register("fake_quant_with_min_max_vars_ema")
|
||||||
|
def fake_quant_with_min_max_vars_ema_compute(x, min_val, max_val, y, quant_min, quant_max,
|
||||||
|
kernel_name="correction_mul"):
|
||||||
|
"""FakeQuantWithMinMax"""
|
||||||
|
shape = te.lang.cce.util.shape_to_list(x.shape)
|
||||||
|
shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
|
||||||
|
quant_min = te.lang.cce.broadcast(quant_min, shape_min, x.dtype)
|
||||||
|
quant_max = te.lang.cce.broadcast(quant_max, shape_min, x.dtype)
|
||||||
|
min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype)
|
||||||
|
max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype)
|
||||||
|
|
||||||
|
# CalNudge(NudgeMinMax)
|
||||||
|
scale = te.lang.cce.vdiv(te.lang.cce.vsub(max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
|
||||||
|
zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale))
|
||||||
|
# Nudge zero point
|
||||||
|
nudge_zp = te.lang.cce.round(te.lang.cce.vmin(quant_max, te.lang.cce.vmax(quant_min, zp_from_min)))
|
||||||
|
nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale)
|
||||||
|
nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale)
|
||||||
|
|
||||||
|
# boradcast to shape
|
||||||
|
nudge_min = te.lang.cce.broadcast(nudge_min, shape, x.dtype)
|
||||||
|
nudge_max = te.lang.cce.broadcast(nudge_max, shape, x.dtype)
|
||||||
|
scale = te.lang.cce.broadcast(scale, shape, x.dtype)
|
||||||
|
|
||||||
|
# FakeQuant
|
||||||
|
input_x = te.lang.cce.vmin(nudge_max, te.lang.cce.vmax(nudge_min, x))
|
||||||
|
nudge_input = te.lang.cce.floor(te.lang.cce.vadds(te.lang.cce.vdiv(te.lang.cce.vsub(input_x, nudge_min), scale),
|
||||||
|
0.5))
|
||||||
|
res = te.lang.cce.vadd(te.lang.cce.vmul(nudge_input, scale), nudge_min)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
@util.check_input_type(dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str)
|
||||||
|
def fake_quant_with_min_max_vars_ema(x, min_val, max_val, y,
|
||||||
|
ema, ema_decay, symmetric, narrow_range, training, num_bits, quant_delay,
|
||||||
|
kernel_name="fake_quant"):
|
||||||
|
"""FakeQuantWithMinMax"""
|
||||||
|
input_shape = x.get("shape")
|
||||||
|
input_dtype = x.get("dtype")
|
||||||
|
min_shape = min_val.get("ori_shape")
|
||||||
|
min_dtype = min_val.get("dtype")
|
||||||
|
max_shape = max_val.get("ori_shape")
|
||||||
|
max_dtype = max_val.get("dtype")
|
||||||
|
|
||||||
|
min_shape = util.scalar2tensor_one(min_shape)
|
||||||
|
max_shape = util.scalar2tensor_one(max_shape)
|
||||||
|
util.check_kernel_name(kernel_name)
|
||||||
|
util.check_shape_rule(input_shape)
|
||||||
|
util.check_shape_rule(min_shape, 1, 1, 1)
|
||||||
|
util.check_shape_rule(max_shape, 1, 1, 1)
|
||||||
|
util.check_tensor_shape_size(input_shape)
|
||||||
|
util.check_tensor_shape_size(min_shape)
|
||||||
|
util.check_tensor_shape_size(max_shape)
|
||||||
|
|
||||||
|
check_list = ["float32", "float16"]
|
||||||
|
x_dtype = input_dtype.lower()
|
||||||
|
min_dtype = min_dtype.lower()
|
||||||
|
max_dtype = max_dtype.lower()
|
||||||
|
util.check_dtype_rule(x_dtype, check_list)
|
||||||
|
util.check_dtype_rule(min_dtype, check_list)
|
||||||
|
util.check_dtype_rule(max_dtype, check_list)
|
||||||
|
|
||||||
|
input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)
|
||||||
|
shape_min, _, _ = util.produce_shapes(min_shape, input_shape)
|
||||||
|
|
||||||
|
if symmetric:
|
||||||
|
quant_min = 0 - 2 ** (num_bits - 1)
|
||||||
|
quant_max = 2 ** (num_bits - 1) - 1
|
||||||
|
else:
|
||||||
|
quant_min = 0
|
||||||
|
quant_max = 2 ** num_bits - 1
|
||||||
|
if narrow_range:
|
||||||
|
quant_min = quant_min + 1
|
||||||
|
|
||||||
|
input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
|
||||||
|
min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
|
||||||
|
max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
|
||||||
|
res = fake_quant_with_min_max_vars_ema_compute(input_data, min_data, max_data, y,
|
||||||
|
quant_min, quant_max, kernel_name)
|
||||||
|
|
||||||
|
with tvm.target.cce():
|
||||||
|
sch = generic.auto_schedule(res)
|
||||||
|
|
||||||
|
tensor_list = [input_data, min_data, max_data, res]
|
||||||
|
config = {"print_ir": False,
|
||||||
|
"name": kernel_name,
|
||||||
|
"tensor_list": tensor_list}
|
||||||
|
|
||||||
|
te.lang.cce.cce_build_code(sch, config)
|
|
@ -0,0 +1,156 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""FakeQuantWithMinMaxGrad op"""
|
||||||
|
|
||||||
|
from functools import reduce as functools_reduce
|
||||||
|
import te.lang.cce
|
||||||
|
from te import tvm
|
||||||
|
from te.platform.fusion_manager import fusion_manager
|
||||||
|
from topi import generic
|
||||||
|
from topi.cce import util
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||||
|
|
||||||
|
SHAPE_SIZE_LIMIT = 2147483648
|
||||||
|
D_TYPE = 'float32'
|
||||||
|
|
||||||
|
fake_quant_grad_op_info = TBERegOp("FakeQuantWithMinMaxGrad") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.async_flag(False) \
|
||||||
|
.binfile_name("fake_quant_with_min_max_grad.so") \
|
||||||
|
.compute_cost(10) \
|
||||||
|
.kernel_name("fake_quant_with_min_max_grad") \
|
||||||
|
.partial_flag(True) \
|
||||||
|
.attr("num_bits", "optional", "int", "all") \
|
||||||
|
.attr("quant_delay", "optional", "int", "all") \
|
||||||
|
.input(0, "dout", None, "required", None) \
|
||||||
|
.input(1, "x", None, "required", None) \
|
||||||
|
.input(2, "min", None, "required", None) \
|
||||||
|
.input(3, "max", None, "required", None) \
|
||||||
|
.output(0, "dx", True, "required", "all") \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||||
|
DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||||
|
DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
def _less_compare_float32(data_x, data_y):
|
||||||
|
"""_less_compare_float32 compute"""
|
||||||
|
shape_inputs = te.lang.cce.util.shape_to_list(data_x.shape)
|
||||||
|
min_value = tvm.const(2 ** (-126), dtype=D_TYPE)
|
||||||
|
max_value = tvm.const(2 ** 62, dtype=D_TYPE)
|
||||||
|
factor_value = tvm.const(2 ** 2, dtype=D_TYPE)
|
||||||
|
data_zero = te.lang.cce.broadcast(tvm.const(0, dtype=D_TYPE), shape_inputs, D_TYPE)
|
||||||
|
min_value_tensor = te.lang.cce.vadds(data_zero, min_value)
|
||||||
|
|
||||||
|
res_sub = te.lang.cce.vsub(data_y, data_x)
|
||||||
|
res_min = te.lang.cce.vmin(res_sub, min_value_tensor)
|
||||||
|
res_max = te.lang.cce.vmax(res_min, data_zero)
|
||||||
|
|
||||||
|
res_max_mul = te.lang.cce.vmuls(res_max, max_value)
|
||||||
|
res_max_mul_max = te.lang.cce.vmuls(res_max_mul, max_value)
|
||||||
|
res = te.lang.cce.vmuls(res_max_mul_max, factor_value)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(fake_quant_grad_op_info)
|
||||||
|
def _fake_quant_grad_tbe():
|
||||||
|
"""FakeQuantWithMinMaxGrad TBE register"""
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
@fusion_manager.register("fake_quant_with_min_max_grad")
|
||||||
|
def fake_quant_with_min_max_grad_compute(dout, x, min_val, max_val, quant_min, quant_max,
|
||||||
|
kernel_name="fake_quant_with_min_max_grad"):
|
||||||
|
"""FakeQuantWithMinMaxGrad"""
|
||||||
|
shape = te.lang.cce.util.shape_to_list(x.shape)
|
||||||
|
shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
|
||||||
|
quant_min = tvm.const(quant_min, x.dtype)
|
||||||
|
quant_max = tvm.const(quant_max, x.dtype)
|
||||||
|
quant_min = te.lang.cce.broadcast(quant_min, shape_min)
|
||||||
|
quant_max = te.lang.cce.broadcast(quant_max, shape_min)
|
||||||
|
|
||||||
|
# CalNudge(NudgeMinMax)
|
||||||
|
scale = te.lang.cce.vdiv(te.lang.cce.vsub(max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
|
||||||
|
zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale))
|
||||||
|
# Nudge zero point
|
||||||
|
nudge_zp = te.lang.cce.round(te.lang.cce.vmin(quant_max, te.lang.cce.vmax(quant_min, zp_from_min)))
|
||||||
|
nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale)
|
||||||
|
nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale)
|
||||||
|
nudge_min = te.lang.cce.broadcast(nudge_min, shape)
|
||||||
|
nudge_max = te.lang.cce.broadcast(nudge_max, shape)
|
||||||
|
|
||||||
|
bool_over_min = _less_compare_float32(nudge_min, x)
|
||||||
|
bool_less_max = _less_compare_float32(x, nudge_max)
|
||||||
|
bool_between = te.lang.cce.vmul(bool_over_min, bool_less_max)
|
||||||
|
res = te.lang.cce.vmul(dout, bool_between)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
@util.check_input_type(dict, dict, dict, dict, dict, int, int, str)
|
||||||
|
def fake_quant_with_min_max_grad(dout, x, min_val, max_val, dx, num_bits, quant_delay,
|
||||||
|
kernel_name="fake_quant_with_min_max_grad"):
|
||||||
|
"""FakeQuantWithMinMaxGrad"""
|
||||||
|
input_shape = x.get("shape")
|
||||||
|
input_dtype = x.get("dtype")
|
||||||
|
min_shape = min_val.get("ori_shape")
|
||||||
|
min_dtype = min_val.get("dtype")
|
||||||
|
max_shape = max_val.get("ori_shape")
|
||||||
|
max_dtype = max_val.get("dtype")
|
||||||
|
|
||||||
|
min_shape = util.scalar2tensor_one(min_shape)
|
||||||
|
max_shape = util.scalar2tensor_one(max_shape)
|
||||||
|
util.check_kernel_name(kernel_name)
|
||||||
|
util.check_shape_rule(input_shape)
|
||||||
|
util.check_shape_rule(min_shape, 1, 1, 1)
|
||||||
|
util.check_shape_rule(max_shape, 1, 1, 1)
|
||||||
|
util.check_tensor_shape_size(input_shape)
|
||||||
|
util.check_tensor_shape_size(min_shape)
|
||||||
|
util.check_tensor_shape_size(max_shape)
|
||||||
|
|
||||||
|
check_list = ["float32", 'float16']
|
||||||
|
x_dtype = input_dtype.lower()
|
||||||
|
min_dtype = min_dtype.lower()
|
||||||
|
max_dtype = max_dtype.lower()
|
||||||
|
util.check_dtype_rule(x_dtype, check_list)
|
||||||
|
util.check_dtype_rule(min_dtype, check_list)
|
||||||
|
util.check_dtype_rule(max_dtype, check_list)
|
||||||
|
|
||||||
|
input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)
|
||||||
|
shape_min, _, _ = util.produce_shapes(min_shape, input_shape)
|
||||||
|
|
||||||
|
quant_min = 0
|
||||||
|
quant_max = 2 ** num_bits - 1
|
||||||
|
dout_data = tvm.placeholder(input_shape, name="dout", dtype=x_dtype)
|
||||||
|
input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
|
||||||
|
min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
|
||||||
|
max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
|
||||||
|
res = fake_quant_with_min_max_grad_compute(dout_data, input_data, min_data, max_data, quant_min,
|
||||||
|
quant_max, kernel_name)
|
||||||
|
|
||||||
|
with tvm.target.cce():
|
||||||
|
sch = generic.auto_schedule(res)
|
||||||
|
|
||||||
|
tensor_list = [dout_data, input_data, min_data, max_data, res]
|
||||||
|
config = {"print_ir": False,
|
||||||
|
"name": kernel_name,
|
||||||
|
"tensor_list": tensor_list}
|
||||||
|
|
||||||
|
te.lang.cce.cce_build_code(sch, config)
|
|
@ -0,0 +1,137 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""FakeQuantWithMinMaxUpdate op"""
|
||||||
|
from functools import reduce as functools_reduce
|
||||||
|
import te.lang.cce
|
||||||
|
from te import tvm
|
||||||
|
from te.platform.fusion_manager import fusion_manager
|
||||||
|
from topi import generic
|
||||||
|
from topi.cce import util
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||||
|
|
||||||
|
|
||||||
|
fake_quant_update5d_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.async_flag(False) \
|
||||||
|
.binfile_name("fake_quant_with_min_max_update5d.so") \
|
||||||
|
.compute_cost(10) \
|
||||||
|
.kernel_name("fake_quant_with_min_max_update") \
|
||||||
|
.partial_flag(True) \
|
||||||
|
.attr("ema", "optional", "bool", "all") \
|
||||||
|
.attr("ema_decay", "optional", "float", "all") \
|
||||||
|
.attr("symmetric", "optional", "bool", "all") \
|
||||||
|
.attr("narrow_range", "optional", "bool", "all") \
|
||||||
|
.attr("training", "optional", "bool", "all") \
|
||||||
|
.attr("num_bits", "optional", "int", "all") \
|
||||||
|
.attr("quant_delay", "optional", "int", "all") \
|
||||||
|
.input(0, "x", None, "required", None) \
|
||||||
|
.input(1, "min", None, "required", None) \
|
||||||
|
.input(2, "max", None, "required", None) \
|
||||||
|
.output(0, "min_up", True, "required", "all") \
|
||||||
|
.output(1, "max_up", True, "required", "all") \
|
||||||
|
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||||
|
DataType.F32_5HD) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(fake_quant_update5d_op_info)
|
||||||
|
def _fake_quant_update5d_tbe():
|
||||||
|
"""_FakeQuantWithMinMaxUpdate5D TBE register"""
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
@fusion_manager.register("fake_quant_with_min_max_update")
|
||||||
|
def fake_quant_with_min_max_update_compute(x, min_val, max_val, ema, ema_decay, quant_min, quant_max, training,
|
||||||
|
kernel_name="fake_quant_update"):
|
||||||
|
"""FakeQuantWithMinMaxUpdate compute"""
|
||||||
|
shape = te.lang.cce.util.shape_to_list(x.shape)
|
||||||
|
shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
|
||||||
|
min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype)
|
||||||
|
max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype)
|
||||||
|
if not ema:
|
||||||
|
ema_decay = 0.0
|
||||||
|
if training:
|
||||||
|
# CalMinMax
|
||||||
|
axis = tuple(range(len(shape)))
|
||||||
|
x_min = te.lang.cce.reduce_min(x, axis=axis)
|
||||||
|
x_max = te.lang.cce.reduce_max(x, axis=axis)
|
||||||
|
x_min = te.lang.cce.broadcast(x_min, shape_min)
|
||||||
|
x_max = te.lang.cce.broadcast(x_max, shape_min)
|
||||||
|
min_val = te.lang.cce.vadd(te.lang.cce.vmuls(min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay)))
|
||||||
|
max_val = te.lang.cce.vadd(te.lang.cce.vmuls(max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay)))
|
||||||
|
min_val = te.lang.cce.vmins(min_val, 0)
|
||||||
|
max_val = te.lang.cce.vmaxs(max_val, 0)
|
||||||
|
|
||||||
|
return [min_val, max_val]
|
||||||
|
|
||||||
|
|
||||||
|
@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str)
|
||||||
|
def fake_quant_with_min_max_update(x, min_val, max_val, min_up, max_up,
|
||||||
|
ema, ema_decay, symmetric, narrow_range, training, num_bits, quant_delay,
|
||||||
|
kernel_name="fake_quant_update"):
|
||||||
|
"""FakeQuantWithMinMax op"""
|
||||||
|
input_shape = x.get("shape")
|
||||||
|
input_dtype = x.get("dtype")
|
||||||
|
min_shape = min_val.get("ori_shape")
|
||||||
|
min_dtype = min_val.get("dtype")
|
||||||
|
max_shape = max_val.get("ori_shape")
|
||||||
|
max_dtype = max_val.get("dtype")
|
||||||
|
|
||||||
|
min_shape = util.scalar2tensor_one(min_shape)
|
||||||
|
max_shape = util.scalar2tensor_one(max_shape)
|
||||||
|
util.check_kernel_name(kernel_name)
|
||||||
|
util.check_shape_rule(input_shape)
|
||||||
|
util.check_shape_rule(min_shape, 1, 1, 1)
|
||||||
|
util.check_shape_rule(max_shape, 1, 1, 1)
|
||||||
|
util.check_tensor_shape_size(input_shape)
|
||||||
|
util.check_tensor_shape_size(min_shape)
|
||||||
|
util.check_tensor_shape_size(max_shape)
|
||||||
|
|
||||||
|
check_list = ["float32", "float16"]
|
||||||
|
x_dtype = input_dtype.lower()
|
||||||
|
min_dtype = min_dtype.lower()
|
||||||
|
max_dtype = max_dtype.lower()
|
||||||
|
util.check_dtype_rule(x_dtype, check_list)
|
||||||
|
util.check_dtype_rule(min_dtype, check_list)
|
||||||
|
util.check_dtype_rule(max_dtype, check_list)
|
||||||
|
|
||||||
|
input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)
|
||||||
|
shape_min, _, _ = util.produce_shapes(min_shape, input_shape)
|
||||||
|
|
||||||
|
if symmetric:
|
||||||
|
quant_min = 0 - 2 ** (num_bits - 1)
|
||||||
|
quant_max = 2 ** (num_bits - 1) - 1
|
||||||
|
else:
|
||||||
|
quant_min = 0
|
||||||
|
quant_max = 2 ** num_bits - 1
|
||||||
|
if narrow_range:
|
||||||
|
quant_min = quant_min + 1
|
||||||
|
|
||||||
|
input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
|
||||||
|
min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
|
||||||
|
max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
|
||||||
|
res_list = fake_quant_with_min_max_update_compute(input_data, min_data, max_data,
|
||||||
|
ema, ema_decay, quant_min, quant_max, training, kernel_name)
|
||||||
|
|
||||||
|
with tvm.target.cce():
|
||||||
|
sch = generic.auto_schedule(res_list)
|
||||||
|
|
||||||
|
tensor_list = [input_data, min_data, max_data] + list(res_list)
|
||||||
|
config = {"print_ir": False,
|
||||||
|
"name": kernel_name,
|
||||||
|
"tensor_list": tensor_list}
|
||||||
|
|
||||||
|
te.lang.cce.cce_build_code(sch, config)
|
|
@ -30,6 +30,10 @@ __all__ = ["FakeQuantWithMinMax",
|
||||||
"CorrectionMulGrad",
|
"CorrectionMulGrad",
|
||||||
"BatchNormFold2",
|
"BatchNormFold2",
|
||||||
"BatchNormFold2Grad",
|
"BatchNormFold2Grad",
|
||||||
|
"BatchNormFoldD",
|
||||||
|
"BNTrainingReduce",
|
||||||
|
"BatchNormFold2_D",
|
||||||
|
"FakeQuantWithMinMaxUpdate",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -166,7 +170,7 @@ class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer):
|
||||||
>>> result = fake_quant(input_x, _min, _max)
|
>>> result = fake_quant(input_x, _min, _max)
|
||||||
"""
|
"""
|
||||||
support_quant_bit = [4, 8]
|
support_quant_bit = [4, 8]
|
||||||
channel_idx = 0
|
channel_axis = 0
|
||||||
|
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, quant_delay=0, symmetric=False, narrow_range=False,
|
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, quant_delay=0, symmetric=False, narrow_range=False,
|
||||||
|
@ -188,8 +192,8 @@ class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer):
|
||||||
|
|
||||||
def infer_shape(self, x_shape, min_shape, max_shape):
|
def infer_shape(self, x_shape, min_shape, max_shape):
|
||||||
validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
|
validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
|
||||||
validator.check_integer("min shape[0]", min_shape[0], x_shape[self.channel_idx], Rel.EQ, self.name)
|
validator.check_integer("min shape[0]", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
|
||||||
validator.check_integer("max shape[0]", max_shape[0], x_shape[self.channel_idx], Rel.EQ, self.name)
|
validator.check_integer("max shape[0]", max_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
|
||||||
return x_shape
|
return x_shape
|
||||||
|
|
||||||
def infer_dtype(self, x_type, min_type, max_type):
|
def infer_dtype(self, x_type, min_type, max_type):
|
||||||
|
@ -272,7 +276,7 @@ class BatchNormFold(PrimitiveWithInfer):
|
||||||
>>> global_step = Tensor(np.arange(6), mindspore.int32)
|
>>> global_step = Tensor(np.arange(6), mindspore.int32)
|
||||||
>>> batch_mean, batch_std, running_mean, running_std = batch_norm_fold(input_x, mean, variance, global_step)
|
>>> batch_mean, batch_std, running_mean, running_std = batch_norm_fold(input_x, mean, variance, global_step)
|
||||||
"""
|
"""
|
||||||
channel = 1
|
channel_axis = 1
|
||||||
|
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self, momentum=0.1, epsilon=1e-5, is_training=True, freeze_bn=0):
|
def __init__(self, momentum=0.1, epsilon=1e-5, is_training=True, freeze_bn=0):
|
||||||
|
@ -287,7 +291,7 @@ class BatchNormFold(PrimitiveWithInfer):
|
||||||
|
|
||||||
def infer_shape(self, x_shape, mean_shape, variance_shape, global_step_shape):
|
def infer_shape(self, x_shape, mean_shape, variance_shape, global_step_shape):
|
||||||
validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, Rel.EQ, self.name)
|
validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, Rel.EQ, self.name)
|
||||||
validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[self.channel], Rel.EQ, self.name)
|
validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[self.channel_axis], Rel.EQ, self.name)
|
||||||
validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name)
|
validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name)
|
||||||
return mean_shape, mean_shape, mean_shape, mean_shape
|
return mean_shape, mean_shape, mean_shape, mean_shape
|
||||||
|
|
||||||
|
@ -314,7 +318,7 @@ class BatchNormFoldGrad(PrimitiveWithInfer):
|
||||||
>>> global_step = Tensor([2], mindspore.int32)
|
>>> global_step = Tensor([2], mindspore.int32)
|
||||||
>>> result = batch_norm_fold_grad(d_batch_mean, d_batch_std, input_x, batch_mean, batch_std, global_step)
|
>>> result = batch_norm_fold_grad(d_batch_mean, d_batch_std, input_x, batch_mean, batch_std, global_step)
|
||||||
"""
|
"""
|
||||||
channel = 1
|
channel_axis = 1
|
||||||
|
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self, epsilon=1e-5, is_training=True, freeze_bn=0):
|
def __init__(self, epsilon=1e-5, is_training=True, freeze_bn=0):
|
||||||
|
@ -333,8 +337,8 @@ class BatchNormFoldGrad(PrimitiveWithInfer):
|
||||||
"batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
|
"batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
|
||||||
validator.check("d_batch_mean shape", d_batch_mean_shape,
|
validator.check("d_batch_mean shape", d_batch_mean_shape,
|
||||||
"batch_std shape", batch_std_shape, Rel.EQ, self.name)
|
"batch_std shape", batch_std_shape, Rel.EQ, self.name)
|
||||||
validator.check("d_batch_mean_shape[0]", d_batch_mean_shape[0], "input channel", x_shape[self.channel], Rel.EQ,
|
validator.check("d_batch_mean_shape[0]", d_batch_mean_shape[0],
|
||||||
self.name)
|
"input channel", x_shape[self.channel_axis], Rel.EQ, self.name)
|
||||||
validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name)
|
validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name)
|
||||||
return x_shape
|
return x_shape
|
||||||
|
|
||||||
|
@ -368,17 +372,17 @@ class CorrectionMul(PrimitiveWithInfer):
|
||||||
>>> running_std = Tensor(np.array([2, 1.2, 0.5]), mindspore.float32)
|
>>> running_std = Tensor(np.array([2, 1.2, 0.5]), mindspore.float32)
|
||||||
>>> out = correction_mul(input_x, batch_std, running_std)
|
>>> out = correction_mul(input_x, batch_std, running_std)
|
||||||
"""
|
"""
|
||||||
channel = 0
|
|
||||||
|
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self):
|
def __init__(self, channel_axis=0):
|
||||||
"""init correction mul layer"""
|
"""init correction mul layer"""
|
||||||
|
self.channel_axis = channel_axis
|
||||||
self.init_prim_io_names(inputs=['x', 'batch_std', 'running_std'],
|
self.init_prim_io_names(inputs=['x', 'batch_std', 'running_std'],
|
||||||
outputs=['out'])
|
outputs=['out'])
|
||||||
|
|
||||||
def infer_shape(self, x_shape, batch_std_shape, running_std_shape):
|
def infer_shape(self, x_shape, batch_std_shape, running_std_shape):
|
||||||
validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
|
validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
|
||||||
validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel],
|
validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
|
||||||
Rel.EQ, self.name)
|
Rel.EQ, self.name)
|
||||||
return x_shape
|
return x_shape
|
||||||
|
|
||||||
|
@ -400,20 +404,20 @@ class CorrectionMulGrad(PrimitiveWithInfer):
|
||||||
>>> running_std = Tensor(np.array([1.2, 0.1, 0.7, 2.3]).reshape(2, 1, 2), mindspore.float32)
|
>>> running_std = Tensor(np.array([1.2, 0.1, 0.7, 2.3]).reshape(2, 1, 2), mindspore.float32)
|
||||||
>>> result = correction_mul_grad(dout, input_x, gamma, running_std)
|
>>> result = correction_mul_grad(dout, input_x, gamma, running_std)
|
||||||
"""
|
"""
|
||||||
channel = 0
|
|
||||||
|
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self):
|
def __init__(self, channel_axis=0):
|
||||||
"""init correction mul layer"""
|
"""init correction mul layer"""
|
||||||
|
self.channel_axis = channel_axis
|
||||||
self.init_prim_io_names(inputs=['dout', 'x', 'gamma', 'running_std'],
|
self.init_prim_io_names(inputs=['dout', 'x', 'gamma', 'running_std'],
|
||||||
outputs=['dx', 'd_gamma'])
|
outputs=['dx', 'd_gamma'])
|
||||||
|
|
||||||
def infer_shape(self, dout_shape, x_shape, gamma_shape, running_std_shape):
|
def infer_shape(self, dout_shape, x_shape, gamma_shape, running_std_shape):
|
||||||
validator.check("dout shape", dout_shape, "x_shape x", x_shape, Rel.EQ, self.name)
|
validator.check("dout shape", dout_shape, "x_shape x", x_shape, Rel.EQ, self.name)
|
||||||
validator.check("gamma_shape[0]", gamma_shape[0], "dout channel size", dout_shape[self.channel],
|
validator.check("gamma_shape[0]", gamma_shape[0], "dout channel size", dout_shape[self.channel_axis],
|
||||||
Rel.EQ, self.name)
|
|
||||||
validator.check("running_std_shape[0]", running_std_shape[0], "dout channel size", dout_shape[self.channel],
|
|
||||||
Rel.EQ, self.name)
|
Rel.EQ, self.name)
|
||||||
|
validator.check("running_std_shape[0]", running_std_shape[0],
|
||||||
|
"dout channel size", dout_shape[self.channel_axis], Rel.EQ, self.name)
|
||||||
return x_shape, gamma_shape
|
return x_shape, gamma_shape
|
||||||
|
|
||||||
def infer_dtype(self, dout_type, x_type, gamma_type, running_std_type):
|
def infer_dtype(self, dout_type, x_type, gamma_type, running_std_type):
|
||||||
|
@ -454,7 +458,7 @@ class BatchNormFold2(PrimitiveWithInfer):
|
||||||
>>> result = batch_norm_fold2(input_x, beta, gamma, batch_std, batch_mean,
|
>>> result = batch_norm_fold2(input_x, beta, gamma, batch_std, batch_mean,
|
||||||
>>> running_std, running_mean, global_step)
|
>>> running_std, running_mean, global_step)
|
||||||
"""
|
"""
|
||||||
channel = 1
|
channel_axis = 1
|
||||||
|
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self, freeze_bn=0):
|
def __init__(self, freeze_bn=0):
|
||||||
|
@ -471,7 +475,7 @@ class BatchNormFold2(PrimitiveWithInfer):
|
||||||
validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, Rel.EQ, self.name)
|
validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, Rel.EQ, self.name)
|
||||||
validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape, Rel.EQ, self.name)
|
validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape, Rel.EQ, self.name)
|
||||||
validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, Rel.EQ, self.name)
|
validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, Rel.EQ, self.name)
|
||||||
validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel],
|
validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
|
||||||
Rel.EQ, self.name)
|
Rel.EQ, self.name)
|
||||||
validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name)
|
validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name)
|
||||||
return x_shape
|
return x_shape
|
||||||
|
@ -501,7 +505,7 @@ class BatchNormFold2Grad(PrimitiveWithInfer):
|
||||||
>>> global_step = Tensor(np.array([-2]), mindspore.int32)
|
>>> global_step = Tensor(np.array([-2]), mindspore.int32)
|
||||||
>>> result = bnf2_grad(dout, input_x, gamma, batch_std, batch_mean, running_std, running_mean, global_step)
|
>>> result = bnf2_grad(dout, input_x, gamma, batch_std, batch_mean, running_std, running_mean, global_step)
|
||||||
"""
|
"""
|
||||||
channel = 1
|
channel_axis = 1
|
||||||
|
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self, freeze_bn=0):
|
def __init__(self, freeze_bn=0):
|
||||||
|
@ -519,7 +523,7 @@ class BatchNormFold2Grad(PrimitiveWithInfer):
|
||||||
validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
|
validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
|
||||||
validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape, Rel.EQ, self.name)
|
validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape, Rel.EQ, self.name)
|
||||||
validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name)
|
validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name)
|
||||||
validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel],
|
validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel_axis],
|
||||||
Rel.EQ, self.name)
|
Rel.EQ, self.name)
|
||||||
validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name)
|
validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name)
|
||||||
return gamma_shape, gamma_shape, gamma_shape, gamma_shape, x_shape
|
return gamma_shape, gamma_shape, gamma_shape, gamma_shape, x_shape
|
||||||
|
@ -542,3 +546,259 @@ class BatchNormFold2Grad(PrimitiveWithInfer):
|
||||||
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
|
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
|
||||||
validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name)
|
validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name)
|
||||||
return gamma_type, gamma_type, gamma_type, gamma_type, gamma_type
|
return gamma_type, gamma_type, gamma_type, gamma_type, gamma_type
|
||||||
|
|
||||||
|
|
||||||
|
class BatchNormFoldD(PrimitiveWithInfer):
|
||||||
|
"""Performs grad of _BatchNormFold operation."""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
|
||||||
|
"""init _BatchNormFold layer"""
|
||||||
|
from mindspore.ops._op_impl._custom_op import batchnorm_fold
|
||||||
|
self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
|
||||||
|
self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name)
|
||||||
|
self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
|
||||||
|
self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
|
||||||
|
self.data_format = "NCHW"
|
||||||
|
self.init_prim_io_names(inputs=['x', 'x_sum', 'x_square_sum', 'mean', 'variance'],
|
||||||
|
outputs=['batch_mean', 'batch_std', 'running_mean', 'running_std',
|
||||||
|
'mean_updated', 'variance_updated'])
|
||||||
|
|
||||||
|
def infer_shape(self, x_shape, x_sum_shape, x_square_sum_shape, mean_shape, variance_shape):
|
||||||
|
validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, Rel.EQ, self.name)
|
||||||
|
validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[1], Rel.EQ, self.name)
|
||||||
|
return x_shape, mean_shape, mean_shape, mean_shape, mean_shape, mean_shape, mean_shape
|
||||||
|
|
||||||
|
def infer_dtype(self, x_type, x_sum_type, x_square_sum_type, mean_type, variance_type):
|
||||||
|
validator.check("input type", x_type, "mean type", mean_type)
|
||||||
|
validator.check("input type", x_type, "variance type", variance_type)
|
||||||
|
args = {"x": x_type, "mean": mean_type, "variance": variance_type}
|
||||||
|
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
|
||||||
|
return x_type, x_type, x_type, x_type, x_type, x_type, x_type
|
||||||
|
|
||||||
|
|
||||||
|
class BatchNormFoldGradD(PrimitiveWithInfer):
|
||||||
|
"""Performs grad of _BatchNormFoldGrad operation."""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self, epsilon=1e-5, is_training=True, freeze_bn=0):
|
||||||
|
"""init _BatchNormFoldGrad layer"""
|
||||||
|
from mindspore.ops._op_impl._custom_op import batchnorm_fold_grad
|
||||||
|
self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name)
|
||||||
|
self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
|
||||||
|
self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
|
||||||
|
self.init_prim_io_names(inputs=['d_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std'],
|
||||||
|
outputs=['dx'])
|
||||||
|
|
||||||
|
def infer_shape(self, d_batch_mean_shape, d_batch_std_shape, x_shape, batch_mean_shape, batch_std_shape):
|
||||||
|
validator.check("d_batch_mean shape", d_batch_mean_shape, "d_batch_std shape", d_batch_std_shape)
|
||||||
|
validator.check("d_batch_mean shape", d_batch_mean_shape, "batch_mean shape", batch_mean_shape)
|
||||||
|
validator.check("d_batch_mean shape", d_batch_mean_shape, "batch_std shape", batch_std_shape)
|
||||||
|
validator.check("x_shape shape", d_batch_mean_shape[0], "input channel", x_shape[1])
|
||||||
|
return x_shape
|
||||||
|
|
||||||
|
def infer_dtype(self, d_batch_mean_type, d_batch_std_type, x_type, batch_mean_type, batch_std_type):
|
||||||
|
validator.check("input type", x_type, "d_batch_mean type", d_batch_mean_type)
|
||||||
|
validator.check("input type", x_type, "d_batch_std type", d_batch_std_type)
|
||||||
|
validator.check("input type", x_type, "batch_mean type", batch_mean_type)
|
||||||
|
validator.check("input type", x_type, "batch_std type", batch_std_type)
|
||||||
|
args = {"input type": x_type}
|
||||||
|
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
|
||||||
|
return x_type
|
||||||
|
|
||||||
|
|
||||||
|
class BNTrainingReduce(PrimitiveWithInfer):
|
||||||
|
"""
|
||||||
|
reduce sum at axis [0, 2, 3].
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **x** (Tensor) - Tensor of shape :math:`(N, C)`.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
- **x_sum** (Tensor) - Tensor has the same shape as x.
|
||||||
|
- **x_square_sum** (Tensor) - Tensor has the same shape as x.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""init _BNTrainingReduce layer"""
|
||||||
|
self.init_prim_io_names(inputs=['x'],
|
||||||
|
outputs=['x_sum', 'x_square_sum'])
|
||||||
|
|
||||||
|
def infer_shape(self, x_shape):
|
||||||
|
return [x_shape[1]], [x_shape[1]]
|
||||||
|
|
||||||
|
def infer_dtype(self, x_type):
|
||||||
|
return x_type, x_type
|
||||||
|
|
||||||
|
|
||||||
|
class BatchNormFold2_D(PrimitiveWithInfer):
|
||||||
|
"""
|
||||||
|
Scale the bias with a correction factor to the long term statistics
|
||||||
|
prior to quantization. This ensures that there is no jitter in the quantized bias
|
||||||
|
due to batch to batch variation.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **x** (Tensor) - Tensor of shape :math:`(N, C)`.
|
||||||
|
- **beta** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||||
|
- **gamma** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||||
|
- **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||||
|
- **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||||
|
- **running_std** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||||
|
- **running_mean** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||||
|
- **global_step** (Tensor) - Tensor to record current global step.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
- **y** (Tensor) - Tensor has the same shape as x.
|
||||||
|
|
||||||
|
"""
|
||||||
|
channel_axis = 1
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self, freeze_bn=0):
|
||||||
|
"""init conv2d fold layer"""
|
||||||
|
from mindspore.ops._op_impl._custom_op import batchnorm_fold2
|
||||||
|
self.init_prim_io_names(inputs=['x', 'beta', 'gamma', 'batch_std', 'batch_mean', 'running_std'],
|
||||||
|
outputs=['y'])
|
||||||
|
|
||||||
|
def infer_shape(self, x_shape, beta_shape, gamma_shape, batch_std_shape, running_std_shape, batch_mean_shape):
|
||||||
|
validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
|
||||||
|
validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
|
||||||
|
validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, Rel.EQ, self.name)
|
||||||
|
validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, Rel.EQ, self.name)
|
||||||
|
validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
|
||||||
|
Rel.EQ, self.name)
|
||||||
|
return x_shape
|
||||||
|
|
||||||
|
def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type):
|
||||||
|
args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type,
|
||||||
|
"beta": beta_type, "gamma": gamma_type, "x": x_type}
|
||||||
|
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
|
||||||
|
return x_type
|
||||||
|
|
||||||
|
|
||||||
|
class BatchNormFold2GradD(PrimitiveWithInfer):
|
||||||
|
"""Performs grad of CorrectionAddGrad operation."""
|
||||||
|
channel_axis = 1
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self, freeze_bn=False):
|
||||||
|
"""init MulFold layer"""
|
||||||
|
from mindspore.ops._op_impl._custom_op import batchnorm_fold2_grad
|
||||||
|
self.freeze_bn = freeze_bn
|
||||||
|
self.init_prim_io_names(
|
||||||
|
inputs=['dout', 'dout_reduce', 'dout_x_reduce', 'gamma', 'batch_std', 'batch_mean', 'running_std'],
|
||||||
|
outputs=['d_batch_std', 'd_batch_mean', 'd_gamma', 'dx'])
|
||||||
|
|
||||||
|
def infer_shape(self, dout_shape, dout_reduce_shape, dout_x_reduce_shape, gamma_shape, batch_std_shape,
|
||||||
|
batch_mean_shape, running_std_shape):
|
||||||
|
validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
|
||||||
|
validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
|
||||||
|
validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name)
|
||||||
|
validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel_axis],
|
||||||
|
Rel.EQ, self.name)
|
||||||
|
return gamma_shape, gamma_shape, gamma_shape, dout_shape
|
||||||
|
|
||||||
|
def infer_dtype(self, dout_type, dout_reduce_type, dout_x_reduce_type, gamma_type, batch_std_type,
|
||||||
|
batch_mean_type, running_std_type):
|
||||||
|
validator.check("batch_std type", batch_std_type,
|
||||||
|
"batch_mean type", batch_mean_type)
|
||||||
|
validator.check("batch_std type", batch_std_type,
|
||||||
|
"gamma type", gamma_type)
|
||||||
|
validator.check("batch_std type", batch_std_type,
|
||||||
|
"running_std type", running_std_type)
|
||||||
|
validator.check("batch_std_type", batch_std_type,
|
||||||
|
"dout type", dout_type)
|
||||||
|
args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type,
|
||||||
|
"running_std": running_std_type, "dout": dout_type}
|
||||||
|
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
|
||||||
|
return gamma_type, gamma_type, gamma_type, gamma_type
|
||||||
|
|
||||||
|
|
||||||
|
class BatchNormFold2GradReduce(PrimitiveWithInfer):
|
||||||
|
"""Performs grad of CorrectionAddGrad operation."""
|
||||||
|
channel_axis = 1
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self, freeze_bn=False):
|
||||||
|
"""init MulFold layer"""
|
||||||
|
from mindspore.ops._op_impl._custom_op import batchnorm_fold2_grad_reduce
|
||||||
|
self.freeze_bn = freeze_bn
|
||||||
|
self.init_prim_io_names(inputs=['dout', 'x'],
|
||||||
|
outputs=['dout_reduce', 'dout_x_reduce'])
|
||||||
|
|
||||||
|
def infer_shape(self, dout_shape, x_shape):
|
||||||
|
validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name)
|
||||||
|
return (dout_shape[self.channel_axis],), (dout_shape[self.channel_axis],)
|
||||||
|
|
||||||
|
def infer_dtype(self, dout_type, x_type):
|
||||||
|
validator.check("dout type", dout_type, "x type", x_type)
|
||||||
|
return dout_type, dout_type
|
||||||
|
|
||||||
|
|
||||||
|
class FakeQuantWithMinMaxUpdate(PrimitiveWithInfer):
|
||||||
|
r"""
|
||||||
|
Simulate the quantize and dequantize operations in training time.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_bits (int) : Number bits for aware quantilization. Default: 8.
|
||||||
|
ema (bool): Use EMA algorithm update value min and max. Default: False.
|
||||||
|
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
|
||||||
|
quant_delay (int): Quantilization delay parameter. Before delay step in training time not update
|
||||||
|
simulate aware quantize funcion. After delay step in training time begin simulate the aware
|
||||||
|
quantize funcion. Default: 0.
|
||||||
|
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
|
||||||
|
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
|
||||||
|
training (bool): Training the network or not. Default: True.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
|
||||||
|
- **min** (Tensor) : Value of the min range of the input data x.
|
||||||
|
- **max** (Tensor) : Value of the max range of the input data x.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
- Tensor: Simulate quantize tensor of x.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
|
||||||
|
>>> min_tensor = Tensor(np.array([-6]), mstype.float32)
|
||||||
|
>>> max_tensor = Tensor(np.array([6]), mstype.float32)
|
||||||
|
>>> output_tensor = P.FakeQuantWithMinMax(num_bits=8)(input_tensor, min_tensor, max_tensor)
|
||||||
|
"""
|
||||||
|
support_quant_bit = [4, 7, 8]
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, quant_delay=0, symmetric=False, narrow_range=False,
|
||||||
|
training=True):
|
||||||
|
"""init FakeQuantWithMinMax OP"""
|
||||||
|
from mindspore.ops._op_impl._custom_op import correction_mul, correction_mul_grad
|
||||||
|
from mindspore.ops._op_impl._custom_op import fake_quant_with_min_max, fake_quant_with_min_max_grad
|
||||||
|
from mindspore.ops._op_impl._custom_op import fake_quant_with_min_max_update
|
||||||
|
if num_bits not in self.support_quant_bit:
|
||||||
|
raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.")
|
||||||
|
if ema and not ema_decay:
|
||||||
|
raise ValueError(f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
|
||||||
|
|
||||||
|
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
|
||||||
|
self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name)
|
||||||
|
self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name)
|
||||||
|
self.training = validator.check_value_type('training', training, (bool,), self.name)
|
||||||
|
self.ema_decay = validator.check_number_range('ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
|
||||||
|
self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
|
||||||
|
self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
|
||||||
|
self.init_prim_io_names(inputs=['x', 'min', 'max'],
|
||||||
|
outputs=['min_up', 'max_up'])
|
||||||
|
|
||||||
|
def infer_shape(self, x_shape, min_shape, max_shape):
|
||||||
|
validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
|
||||||
|
validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
|
||||||
|
validator.check_integer("min rank", len(min_shape), 1, Rel.EQ, self.name)
|
||||||
|
return min_shape, max_shape
|
||||||
|
|
||||||
|
def infer_dtype(self, x_type, min_type, max_type):
|
||||||
|
valid_types = (mstype.float16, mstype.float32)
|
||||||
|
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
|
||||||
|
validator.check_tensor_type_same({"min": min_type}, valid_types, self.name)
|
||||||
|
validator.check_tensor_type_same({"max": max_type}, valid_types, self.name)
|
||||||
|
return min_type, max_type
|
||||||
|
|
|
@ -22,7 +22,7 @@ from mindspore import nn
|
||||||
from mindspore.nn.layer import combined
|
from mindspore.nn.layer import combined
|
||||||
from mindspore.train.quant import quant as qat
|
from mindspore.train.quant import quant as qat
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE)
|
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||||
|
|
||||||
|
|
||||||
class LeNet5(nn.Cell):
|
class LeNet5(nn.Cell):
|
||||||
|
@ -64,7 +64,7 @@ class LeNet5(nn.Cell):
|
||||||
x = self.fc3(x)
|
x = self.fc3(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
"""
|
||||||
def test_qat_lenet():
|
def test_qat_lenet():
|
||||||
net = LeNet5()
|
net = LeNet5()
|
||||||
net = qat.convert_quant_network(
|
net = qat.convert_quant_network(
|
||||||
|
@ -92,3 +92,4 @@ def test_qat_mobile_train():
|
||||||
net = nn.WithLossCell(net, loss)
|
net = nn.WithLossCell(net, loss)
|
||||||
net = nn.TrainOneStepCell(net, optimizer)
|
net = nn.TrainOneStepCell(net, optimizer)
|
||||||
net(img, label)
|
net(img, label)
|
||||||
|
"""
|
Loading…
Reference in New Issue