!1504 add custom tbe ops for quant aware training

Merge pull request !1504 from wandongdong/master
mindspore-ci-bot 2020-05-27 19:18:22 +08:00 committed by Gitee
commit c8b30f9290
14 changed files with 2059 additions and 102 deletions

View File

@@ -22,11 +22,15 @@ from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor
from mindspore._checkparam import check_int_positive, check_bool, twice
from mindspore._checkparam import Validator as validator
from mindspore.nn.cell import Cell
from mindspore.nn.layer.activation import get_activation
import mindspore.context as context
__all__ = [
'FakeQuantWithMinMax',
'DepthwiseConv2dBatchNormQuant',
'Conv2dBatchNormQuant',
'Conv2dQuant',
'DenseQuant',
@@ -39,6 +43,169 @@ __all__ = [
]
class BatchNormFoldCell(Cell):
"""
Batch normalization folded.
Args:
momentum (float): Momentum value should be in [0, 1]. Default: 0.9.
epsilon (float): A small float number to avoid dividing by 0. 1e-5 if dtype in
float32 else 1e-3. Default: 1e-5.
freeze_bn (int): Delay in steps at which computation switches from regular batch
norm to frozen mean and std. Default: 0.
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C, H, W)`.
- **mean** (Tensor) - Tensor of shape :math:`(C,)`.
- **variance** (Tensor) - Tensor of shape :math:`(C,)`.
- **global_step** (Tensor) - Tensor to record current global step.
Outputs:
Tuple of 4 Tensor, the normalized input and the updated parameters.
- **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`.
- **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
- **running_mean** (Tensor) - Tensor of shape :math:`(C,)`.
- **running_std** (Tensor) - Tensor of shape :math:`(C,)`.
"""
def __init__(self, momentum=0.9, epsilon=1e-5, freeze_bn=0):
"""init batch norm fold layer"""
super(BatchNormFoldCell, self).__init__()
self.epsilon = epsilon
self.is_gpu = context.get_context('device_target') == "GPU"
if self.is_gpu:
self.bn_train = P.BatchNormFold(momentum, epsilon, is_training=True, freeze_bn=freeze_bn)
self.bn_infer = P.BatchNormFold(momentum, epsilon, is_training=False, freeze_bn=freeze_bn)
else:
self.bn_reduce = P.BNTrainingReduce()
self.bn_update = P.BatchNormFoldD(momentum, epsilon, is_training=True, freeze_bn=freeze_bn)
def construct(self, x, mean, variance, global_step):
if self.is_gpu:
if self.training:
batch_mean, batch_std, running_mean, running_std = self.bn_train(x, mean, variance, global_step)
else:
batch_mean, batch_std, running_mean, running_std = self.bn_infer(x, mean, variance, global_step)
else:
if self.training:
x_sum, x_square_sum = self.bn_reduce(x)
_, batch_mean, batch_std, running_mean, running_std, mean_updated, variance_updated = \
self.bn_update(x, x_sum, x_square_sum, mean, variance)
P.Assign()(mean, mean_updated)
P.Assign()(variance, variance_updated)
else:
batch_mean = P.ZerosLike()(variance)
batch_std = P.OnesLike()(variance)
running_mean = P.TensorAdd()(mean, 0.)
running_std = P.Sqrt()(P.TensorAdd()(variance, self.epsilon))
return batch_mean, batch_std, running_mean, running_std
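For reference, a minimal NumPy sketch of the four statistics this cell is documented to return (an illustrative assumption about the underlying math only, not the GPU or Ascend kernels; the helper name batchnorm_fold_reference is made up here, and the unbiased-variance and freeze_bn handling of the real op are simplified away):

import numpy as np

def batchnorm_fold_reference(x, moving_mean, moving_variance, epsilon=1e-5):
    # x is NCHW; batch statistics are computed per channel over N, H, W
    batch_mean = x.mean(axis=(0, 2, 3))
    batch_std = np.sqrt(x.var(axis=(0, 2, 3)) + epsilon)
    # running statistics are just the stored moving statistics, std-ified
    running_mean = moving_mean
    running_std = np.sqrt(moving_variance + epsilon)
    return batch_mean, batch_std, running_mean, running_std

x = np.random.randn(2, 3, 4, 4).astype(np.float32)
stats = batchnorm_fold_reference(x, np.zeros(3, np.float32), np.ones(3, np.float32))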
class FakeQuantWithMinMaxD(Cell):
r"""
Quantization-aware training op for Ascend. This op provides a fake quantization observer
function on data with min and max.
Args:
min_init (int, list): Initial minimum value, one per channel or a single value for the layer. Default: -6.
max_init (int, list): Initial maximum value, one per channel or a single value for the layer. Default: 6.
num_bits (int): Number of quantization bits; 4 and 8 are supported. Default: 8.
ema (bool): Whether to update min and max with an exponential moving average. Default: False.
ema_decay (float): Decay of the exponential moving average. Default: 0.999.
per_channel (bool): Quantize per channel (True) or per layer (False). Default: False.
channel_size (int): Number of channels for min and max. Default: 1.
quant_delay (int): Number of steps after which quantization is applied, counted by the global step. Default: 0.
symmetric (bool): Whether the quantization algorithm is symmetric. Default: False.
narrow_range (bool): Whether the quantization algorithm uses narrow range. Default: False.
Inputs:
- **x** (Tensor) - The input of FakeQuantWithMinMax.
Outputs:
Tensor, with the same type and shape as the `x`.
Examples:
>>> fake_quant = FakeQuantWithMinMaxD()
>>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
>>> result = fake_quant(input_x, fake_quant.minq, fake_quant.maxq)
"""
def __init__(self,
min_init=-6,
max_init=6,
num_bits=8,
ema=False,
ema_decay=0.999,
per_channel=False,
channel_size=1,
quant_delay=0,
symmetric=False,
narrow_range=False,
training=True):
"""init FakeQuantWithMinMax ascend layer"""
super(FakeQuantWithMinMaxD, self).__init__()
self.min_init = min_init
self.num_bits = num_bits
self.max_init = max_init
self.ema = ema
self.ema_decay = ema_decay
self.per_channel = per_channel
self.channel_size = channel_size
self.quant_delay = quant_delay
self.symmetric = symmetric
self.narrow_range = narrow_range
self.training = training
if not per_channel:
self.fake_quant = P.FakeQuantWithMinMax(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=training)
self.ema_update = P.FakeQuantWithMinMaxUpdate(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=training)
else:
raise RuntimeError("not support per channel")
if isinstance(min_init, Parameter):
self.minq = min_init
self.maxq = max_init
else:
self.minq = Parameter(Tensor(np.array([min_init]).astype(np.float32)),
name='quant_min',
requires_grad=False)
self.maxq = Parameter(Tensor(np.array([max_init]).astype(np.float32)),
name='quant_max',
requires_grad=False)
self.reduce_min = P.ReduceMin()
self.reduce_max = P.ReduceMax()
def extend_repr(self):
s = 'min_init={}, max_init={}, ema={}, ema_decay={}, per_channel={}, channel_size={}, quant_delay={}'.format(
self.min_init, self.max_init, self.ema, self.ema_decay, self.per_channel, self.channel_size,
self.quant_delay)
return s
def construct(self, x, minq, maxq):
if self.training:
min_up, max_up = self.ema_update(x, minq, maxq)
out = self.fake_quant(x, min_up, max_up)
P.Assign()(self.minq, min_up)
P.Assign()(self.maxq, max_up)
else:
out = self.fake_quant(x, minq, maxq)
return out
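As a rough illustration of the observer math this cell wraps, here is a minimal NumPy fake-quantize sketch (a common uniform quantize-dequantize formulation assumed for clarity only; the function name fake_quant_reference is made up, and the production ops additionally handle EMA updates of min and max, quant_delay, and the symmetric/narrow_range variants):

import numpy as np

def fake_quant_reference(x, min_val, max_val, num_bits=8):
    # quantize to [0, 2^num_bits - 1] using an affine scale/zero-point, then dequantize
    quant_max = 2 ** num_bits - 1
    scale = (max_val - min_val) / quant_max
    zero_point = np.round(-min_val / scale)
    q = np.clip(np.round(x / scale) + zero_point, 0, quant_max)
    return (q - zero_point) * scale

x = np.array([[1., 2., 1.], [-2., 0., -1.]], np.float32)
out = fake_quant_reference(x, min_val=-6.0, max_val=6.0)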
class FakeQuantWithMinMax(Cell):
r"""
Quantization-aware training op. This op provides a fake quantization observer function on data with min and max.
@@ -62,7 +229,7 @@ class FakeQuantWithMinMax(Cell):
Tensor, with the same type and shape as the `x`.
Examples:
>>> fake_quant = FakeQuantWithMinMax()
>>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
>>> result = fake_quant(input_x)
"""
@@ -77,7 +244,9 @@ class FakeQuantWithMinMax(Cell):
out_channels=1,
quant_delay=0,
symmetric=False,
narrow_range=False,
training=True):
"""init FakeQuantWithMinMax layer"""
super(FakeQuantWithMinMax, self).__init__()
self.min_init = min_init
@@ -90,12 +259,13 @@ class FakeQuantWithMinMax(Cell):
self.quant_delay = quant_delay
self.symmetric = symmetric
self.narrow_range = narrow_range
self.training = training
if per_channel:
min_array = np.array([self.min_init for i in range(0, self.out_channels)]).astype(np.float32)
max_array = np.array([self.max_init for i in range(0, self.channel_size)]).astype(np.float32)
self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
self.fake_quant_train = P.FakeQuantWithMinMaxPerChannel(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
@@ -113,25 +283,44 @@ class FakeQuantWithMinMax(Cell):
else:
min_array = np.array([min_init]).reshape(1).astype(np.float32)
max_array = np.array([max_init]).reshape(1).astype(np.float32)
self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
if context.get_context('device_target') == "Ascend":
self.fake_quant_train = FakeQuantWithMinMaxD(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=True,
min_init=self.minq,
max_init=self.maxq)
self.fake_quant_infer = FakeQuantWithMinMaxD(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=False,
min_init=self.minq,
max_init=self.maxq)
elif context.get_context('device_target') == "GPU":
self.fake_quant_train = P.FakeQuantWithMinMax(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=True)
self.fake_quant_infer = P.FakeQuantWithMinMax(num_bits=self.num_bits,
ema=self.ema,
ema_decay=ema_decay,
quant_delay=quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=False)
else:
raise ValueError("Not support platform.")
def extend_repr(self):
s = 'min={}, max={}, ema={}, ema_decay={}, per_channel={}, quant_delay={}'.format(
@@ -146,6 +335,191 @@ class FakeQuantWithMinMax(Cell):
return out
class DepthwiseConv2dBatchNormQuant(Cell):
r"""
2D depthwise convolution with a folded BatchNorm layer.
See Conv2d for a more detailed overview of the convolution arguments.
Args:
in_channels (int): The number of input channel :math:`C_{in}`.
out_channels (int): The number of output channel :math:`C_{out}`.
kernel_size (Union[int, tuple]): Specifies the height and width of the 2D convolution window.
stride (int): Specifies stride for all spatial dimensions with the same value.
pad_mode (str): Specifies the padding mode. The optional values are "same", "valid" and "pad". Default: "same".
padding (int): Implicit paddings on both sides of the input. Default: 0.
eps (float): Epsilon value for the folded BatchNorm. Default: 1e-5.
momentum (float): Momentum for the folded BatchNorm moving statistics. Default: 0.997.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
convolution kernel. Default: None.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
beta vector. Default: None.
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
gamma vector. Default: None.
mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
mean vector. Default: None.
var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
variance vector. Default: None.
quant_delay (int): Number of steps after which weight quantization is applied, counted by the global step. Default: 0.
freeze_bn (int): Step count after which the BatchNorm statistics are frozen, counted by the global step. Default: 100000.
fake (bool): Whether to insert a FakeQuantWithMinMax op on the folded weight. Default: True.
num_bits (int): Number of quantization bits; 4 and 8 are supported. Default: 8.
per_channel (bool): Quantize per channel (True) or per layer (False). Default: False.
symmetric (bool): Whether the quantization algorithm is symmetric. Default: False.
narrow_range (bool): Whether the quantization algorithm uses narrow range. Default: False.
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
Outputs:
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
Examples:
>>> quant = nn.DepthwiseConv2dBatchNormQuant(1, 6,
...                                          kernel_size=(2, 2),
...                                          stride=(1, 1),
...                                          pad_mode="valid",
...                                          dilation=(1, 1))
>>> input_x = Tensor(np.random.randint(-2, 2, (2, 1, 1, 3)), mindspore.float32)
>>> result = quant(input_x)
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
pad_mode='same',
padding=0,
dilation=1,
group=1,
eps=1e-5,
momentum=0.997,
weight_init=None,
beta_init=None,
gamma_init=None,
mean_init=None,
var_init=None,
quant_delay=0,
freeze_bn=100000,
fake=True,
num_bits=8,
per_channel=False,
symmetric=False,
narrow_range=False):
"""init DepthwiseConv2dBatchNormQuant layer"""
super(DepthwiseConv2dBatchNormQuant, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.pad_mode = pad_mode
self.padding = padding
self.dilation = twice(dilation)
self.stride = twice(stride)
self.group = group
self.fake = fake
self.freeze_bn = freeze_bn
self.momentum = momentum
self.quant_delay = quant_delay
if isinstance(kernel_size, int):
self.kernel_size = (kernel_size, kernel_size)
else:
self.kernel_size = kernel_size
if group > 1:
validator.check_integer('group', group, 'in_channels', in_channels, 'Conv2dBatchNormQuant')
validator.check_integer('group', group, 'in_channels', out_channels, 'Conv2dBatchNormQuant')
self.is_depthwise = group > 1
channel_multiplier = out_channels // in_channels
self.conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier,
kernel_size=kernel_size,
stride=stride,
pad_mode=pad_mode,
pad=padding)
if weight_init is None:
weight_init = initializer('normal', [channel_multiplier, in_channels, *kernel_size])
self.weight = Parameter(weight_init, name='weight')
if gamma_init is None:
gamma_init = initializer('ones', [out_channels])
self.gamma = Parameter(gamma_init, name='gamma')
if beta_init is None:
beta_init = initializer('zeros', [out_channels])
self.beta = Parameter(beta_init, name='beta')
if mean_init is None:
mean_init = initializer('zeros', [out_channels])
self.moving_mean = Parameter(
mean_init, name='moving_mean', requires_grad=False)
if var_init is None:
var_init = initializer('ones', [out_channels])
self.moving_variance = Parameter(
var_init, name='moving_variance', requires_grad=False)
self.step = Parameter(initializer(
'normal', [1], dtype=mstype.int32), name='step', requires_grad=False)
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
max_init=6,
ema=False,
num_bits=num_bits,
quant_delay=quant_delay,
per_channel=per_channel,
out_channels=out_channels,
symmetric=symmetric,
narrow_range=narrow_range)
self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn)
self.correct_mul = P.CorrectionMul(self.is_depthwise)
if context.get_context('device_target') == "Ascend":
self.batchnorm_fold2_train = P.BatchNormFold2_D(freeze_bn=freeze_bn)
self.batchnorm_fold2_infer = P.BatchNormFold2_D(freeze_bn=0)
elif context.get_context('device_target') == "GPU":
self.batchnorm_fold2_train = P.BatchNormFold2(freeze_bn=freeze_bn)
self.batchnorm_fold2_infer = P.BatchNormFold2(freeze_bn=0)
else:
raise ValueError("Not support platform.")
self.one = Tensor(1, mstype.int32)
self.assignadd = P.AssignAdd()
self.is_gpu = context.get_context('device_target') == "GPU"
def extend_repr(self):
s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
'pad_mode={}, padding={}, dilation={}, group={}, ' \
'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format(
self.in_channels, self.out_channels, self.kernel_size, self.stride,
self.pad_mode, self.padding, self.dilation, self.group,
self.fake, self.freeze_bn, self.momentum, self.quant_delay)
return s
def construct(self, x):
out_conv = self.conv(x, self.weight)
# BN fold1
batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold(out_conv,
self.moving_mean,
self.moving_variance,
self.step)
# fake weight
weight = self.correct_mul(self.weight, self.gamma, running_std)
if self.fake:
weight = self.fake_quant_weight(weight)
out = self.conv(x, weight)
# BN fold2
if self.is_gpu:
if self.training:
out = self.batchnorm_fold2_train(out, self.beta, self.gamma,
batch_std, batch_mean, running_std, running_mean, self.step)
F.control_depend(out, self.assignadd(self.step, self.one))
else:
out = self.batchnorm_fold2_infer(out, self.beta, self.gamma,
batch_std, batch_mean, running_std, running_mean, self.step)
else:
if self.training:
out = self.batchnorm_fold2_train(out, self.beta, self.gamma, batch_std, batch_mean, running_std)
F.control_depend(out, self.assignadd(self.step, self.one))
else:
out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, batch_std, batch_mean, running_std)
return out
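To see why the folded path in construct matches a plain convolution followed by BatchNorm, here is a per-channel NumPy sketch of the CorrectionMul and BatchNormFold2 algebra (names such as out_conv and the scalar statistics below are illustrative assumptions, not part of the cell's API):

import numpy as np

eps = 1e-5
out_conv = np.random.randn(4).astype(np.float32)   # conv output for one channel at 4 positions
gamma, beta = 1.3, 0.2
batch_mean = out_conv.mean()
batch_std = np.sqrt(out_conv.var() + eps)
running_std = np.sqrt(1.0 + eps)                   # sqrt(moving_variance + eps)

# CorrectionMul scales the weight (hence the conv output) by gamma / running_std;
# BatchNormFold2 then multiplies by running_std / batch_std and adds
# beta - gamma * batch_mean / batch_std.
folded = out_conv * (gamma / running_std) * (running_std / batch_std) + (beta - gamma * batch_mean / batch_std)
plain_bn = gamma * (out_conv - batch_mean) / batch_std + beta
assert np.allclose(folded, plain_bn)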
class Conv2dBatchNormQuant(Cell):
r"""
2D convolution with BatchNormal op folded layer.
@@ -215,6 +589,7 @@ class Conv2dBatchNormQuant(Cell):
per_channel=False,
symmetric=False,
narrow_range=False):
"""init Conv2dBatchNormQuant layer"""
super(Conv2dBatchNormQuant, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
@@ -231,7 +606,6 @@ class Conv2dBatchNormQuant(Cell):
self.kernel_size = (kernel_size, kernel_size)
else:
self.kernel_size = kernel_size
if weight_init is None:
weight_init = initializer(
'normal', [out_channels, in_channels // group, *self.kernel_size])
@@ -254,14 +628,6 @@ class Conv2dBatchNormQuant(Cell):
self.step = Parameter(initializer(
'normal', [1], dtype=mstype.int32), name='step', requires_grad=False)
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
max_init=6,
ema=False,
@@ -271,23 +637,29 @@ class Conv2dBatchNormQuant(Cell):
out_channels=out_channels,
symmetric=symmetric,
narrow_range=narrow_range)
self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn)
self.conv = P.Conv2D(out_channel=out_channels,
kernel_size=kernel_size,
mode=1,
pad_mode=pad_mode,
pad=padding,
stride=stride,
dilation=1,
group=group)
self.correct_mul = P.CorrectionMul()
if context.get_context('device_target') == "Ascend":
self.batchnorm_fold2_train = P.BatchNormFold2_D(freeze_bn=freeze_bn)
self.batchnorm_fold2_infer = P.BatchNormFold2_D(freeze_bn=0)
elif context.get_context('device_target') == "GPU":
self.batchnorm_fold2_train = P.BatchNormFold2(freeze_bn=freeze_bn)
self.batchnorm_fold2_infer = P.BatchNormFold2(freeze_bn=0)
else:
raise ValueError("Not support platform.")
self.one = Tensor(1, mstype.int32)
self.assignadd = P.AssignAdd()
def extend_repr(self):
s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
'pad_mode={}, padding={}, dilation={}, group={}, ' \
'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format(
self.in_channels, self.out_channels, self.kernel_size, self.stride,
@@ -296,34 +668,32 @@ class Conv2dBatchNormQuant(Cell):
return s
def construct(self, x):
out_conv = self.conv(x, self.weight)
# BN fold1
batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold(out_conv,
self.moving_mean,
self.moving_variance,
self.step)
# fake weight
weight = self.correct_mul(self.weight, self.gamma, running_std)
if self.fake:
weight = self.fake_quant_weight(weight)
out = self.conv(x, weight)
# BN fold2
if self.is_gpu:
if self.training:
out = self.batchnorm_fold2_train(out, self.beta, self.gamma,
batch_std, batch_mean, running_std, running_mean, self.step)
F.control_depend(out, self.assignadd(self.step, self.one))
else:
out = self.batchnorm_fold2_infer(out, self.beta, self.gamma,
batch_std, batch_mean, running_std, running_mean, self.step)
else:
if self.training:
out = self.batchnorm_fold2_train(out, self.beta, self.gamma, batch_std, batch_mean, running_std)
F.control_depend(out, self.assignadd(self.step, self.one))
else:
out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, batch_std, batch_mean, running_std)
return out
@@ -434,7 +804,7 @@ class Conv2dQuant(Cell):
return out
def extend_repr(self):
s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
'pad_mode={}, padding={}, dilation={}, group={}, ' \
'has_bias={}, quant_delay={}'.format(
self.in_channels, self.out_channels, self.kernel_size, self.stride,

View File

@@ -22,7 +22,7 @@ from ..composite.multitype_ops.zeros_like_impl import zeros_like
@bprop_getters.register(P.FakeQuantWithMinMax)
def get_bprop_fakequant_with_minmax(self):
"""Generate bprop for FakeQuantWithMinMax for GPU and Ascend"""
op = P.FakeQuantWithMinMaxGrad(num_bits=self.num_bits, quant_delay=self.quant_delay)
def bprop(x, x_min, x_max, out, dout):
@@ -34,7 +34,7 @@ def get_bprop_fakequant_with_minmax(self):
@bprop_getters.register(P.FakeQuantWithMinMaxPerChannel)
def get_bprop_fakequant_with_minmax_perchannel(self):
"""Generate bprop for FakeQuantWithMinMaxPerChannel for GPU"""
op = P.FakeQuantWithMinMaxPerChannelGrad(num_bits=self.num_bits, quant_delay=self.quant_delay)
def bprop(x, x_min, x_max, out, dout):
@@ -46,7 +46,7 @@ def get_bprop_fakequant_with_minmax_perchannel(self):
@bprop_getters.register(P.BatchNormFold)
def get_bprop_batchnorm_fold(self):
"""Generate bprop for BatchNormFold for GPU"""
op = P.BatchNormFoldGrad(self.epsilon, self.is_training, self.freeze_bn)
def bprop(x, mean, variance, global_step, out, dout):
@@ -58,8 +58,8 @@ def get_bprop_batchnorm_fold(self):
@bprop_getters.register(P.CorrectionMul)
def get_bprop_correction_mul(self):
"""Generate bprop for CorrectionMul for Ascend and GPU"""
grad = P.CorrectionMulGrad(self.channel_axis)
def bprop(x, batch_std, running_std, out, dout):
dx, d_batch_std = grad(dout, x, batch_std, running_std)
@@ -70,7 +70,7 @@ def get_bprop_correction_mul(self):
@bprop_getters.register(P.BatchNormFold2)
def get_bprop_batchnorm_fold2(self):
"""Generate bprop for BatchNormFold2 for GPU"""
op_f = P.BatchNormFold2Grad(freeze_bn=self.freeze_bn)
def bprop(x, beta, gamma, batch_std, batch_mean, running_std, running_mean, global_step, out, dout):
@@ -80,3 +80,48 @@ def get_bprop_batchnorm_fold2(self):
zeros_like(global_step)
return bprop
@bprop_getters.register(P.BatchNormFoldD)
def get_bprop_BatchNormFold(self):
"""Generate bprop for BatchNormFold for Ascend"""
op = P.BatchNormFoldGrad_(self.epsilon, self.is_training, self.freeze_bn)
def bprop(x, x_sum, x_square_sum, mean, variance, out, dout):
dx = op(dout[1], dout[2], x, out[1], out[2])
return dx, zeros_like(x_sum), zeros_like(x_square_sum), zeros_like(mean), zeros_like(variance)
return bprop
@bprop_getters.register(P.BNTrainingReduce)
def get_bprop_BNTrainingReduce(self):
def bprop(x, out, dout):
return (zeros_like(x),)
return bprop
@bprop_getters.register(P.BatchNormFold2_D)
def get_bprop_batchnorm_fold2_(self):
"""Generate bprop for BatchNormFold2 for Ascend"""
op_reduce = P.BatchNormFold2GradReduce(freeze_bn=self.freeze_bn)
op_f = P.BatchNormFold2GradD(freeze_bn=self.freeze_bn)
def bprop(x, beta, gamma, batch_std, batch_mean, running_std, out, dout):
dout_reduce, dout_x_reduce = op_reduce(dout, x)
d_batch_std, d_batch_mean, d_gamma, d_x = op_f(dout, dout_reduce, dout_x_reduce, gamma, batch_std,
batch_mean, running_std)
return d_x, dout_reduce, d_gamma, d_batch_std, d_batch_mean, zeros_like(running_std)
return bprop
@bprop_getters.register(P.FakeQuantWithMinMaxUpdate)
def get_bprop_fakequant_with_minmax_update(self):
"""Generate bprop for FakeQuantWithMinMaxUpdate for Ascend"""
def bprop(x, x_min, x_max, out, dout):
return zeros_like(x), zeros_like(x_min), zeros_like(x_max)
return bprop

View File

@@ -0,0 +1,149 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""_BatchNormFold op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
import te.lang.cce
from te import tvm
from topi import generic
from topi.cce import util
batch_norm_op_info = TBERegOp("BatchNormFoldD") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("batchnorm_fold.so") \
.compute_cost(10) \
.kernel_name("batchnorm_fold") \
.partial_flag(True) \
.attr("momentum", "optional", "float", "all") \
.attr("epsilon", "optional", "float", "all") \
.attr("is_training", "optional", "bool", "all") \
.attr("freeze_bn", "optional", "int", "all") \
.attr("data_format", "optional", "str", "all") \
.input(0, "x", False, "required", "all") \
.input(1, "x_sum", False, "required", "all") \
.input(2, "x_square_sum", False, "required", "all") \
.input(3, "mean", False, "required", "all") \
.input(4, "variance", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.output(1, "batch_mean", False, "required", "all") \
.output(2, "batch_std", False, "required", "all") \
.output(3, "running_mean", False, "required", "all") \
.output(4, "running_std", False, "required", "all") \
.output(5, "mean_updated", False, "required", "all") \
.output(6, "variance_updated", False, "required", "all") \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
@op_info_register(batch_norm_op_info)
def _batchnorm_fold_tbe():
"""_BatchNormFold TBE register"""
return
@util.check_input_type(dict, dict, dict, dict, dict,
dict, dict, dict, dict, dict, dict, dict,
float, float, bool, int, str, str)
def batchnorm_fold(x, x_sum, x_square_sum, mean, variance,
y, batch_mean, batch_std, running_mean, running_std, mean_updated, variance_updated,
momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0, data_format="NCHW",
kernel_name="batchnorm_fold"):
"""batchnorm_fold TBE op"""
momentum = 1.0 - momentum
util.check_kernel_name(kernel_name)
data_format = data_format.upper()
if data_format != "NCHW":
raise RuntimeError("The data_format only support NCHW")
shape_x = x.get("shape")
shape_mean = mean.get("shape")
shape_variance = variance.get("shape")
dtype_x = x.get("dtype")
dtype_mean = mean.get("dtype")
dtype_variance = variance.get("dtype")
for shape in (shape_x, shape_mean, shape_variance):
util.check_shape_rule(shape)
util.check_tensor_shape_size(shape)
check_tuple = ("float16", "float32")
for dtype in (dtype_x, dtype_mean, dtype_variance):
util.check_dtype_rule(dtype.lower(), check_tuple)
format_data = x.get("format").upper()
if format_data not in ("NCHW", "NC1HWC0"):
raise RuntimeError("Format of input only support 4D and 5HD")
if format_data == "NC1HWC0":
if len(shape_x) != 5:
raise RuntimeError("batchnorm_fold only support shape 5D"
"when input format is NC1HWC0")
shape_mean = (1, shape_x[1], 1, 1, shape_x[4])
elif format_data == "NCHW":
if len(shape_x) < 2 or len(shape_x) > 4:
raise RuntimeError("batchnorm_fold only support shape 2D to 4D")
if shape_x[1] != shape_mean[0]:
raise RuntimeError("data_format is NCHW, shape_bias must"
"be equal to the second axis of shape_x")
shape_mean = (1, shape_x[1],)
for _ in range(2, len(shape_x)):
shape_mean = shape_mean + (1,)
x_input = tvm.placeholder(shape_x, name="x_input", dtype=dtype_x.lower())
x_sum = tvm.placeholder(shape_mean, name="x_sum", dtype=dtype_x.lower())
x_square_sum = tvm.placeholder(shape_mean, name="x_square_sum", dtype=dtype_x.lower())
mean = tvm.placeholder(shape_mean, name="mean", dtype=dtype_mean.lower())
variance = tvm.placeholder(shape_mean, name="variance", dtype=dtype_variance.lower())
shape_x = te.lang.cce.util.shape_to_list(x_input.shape)
num = shape_x[0] * shape_x[2] * shape_x[3]
num_rec = 1.0 / num
# compute the mean of x
batch_mean = te.lang.cce.vmuls(x_sum, num_rec)
# compute the variance of x
variance_div = te.lang.cce.vmuls(x_square_sum, num_rec)
mean_square = te.lang.cce.vmul(batch_mean, batch_mean)
batch_var_biased = te.lang.cce.vsub(variance_div, mean_square)
if num == 1:
batch_var_scaler = 0.0
else:
batch_var_scaler = float(num) / (num - 1)
batch_variance = te.lang.cce.vmuls(batch_var_biased, batch_var_scaler)
batch_std = te.lang.cce.vsqrt(te.lang.cce.vadds(batch_variance, epsilon))
factor = 1.0 - momentum
factor_reverse = momentum
mean_mul = te.lang.cce.vmuls(batch_mean, factor)
mean_mul_rev = te.lang.cce.vmuls(mean, factor_reverse)
mean_updated = te.lang.cce.vadd(mean_mul, mean_mul_rev)
var_mul = te.lang.cce.vmuls(batch_variance, factor)
var_mul_rev = te.lang.cce.vmuls(variance, factor_reverse)
variance_updated = te.lang.cce.vadd(var_mul, var_mul_rev)
y = te.lang.cce.vadds(x_input, 0.0)
running_mean = te.lang.cce.vadds(mean, 0.0)
running_std = te.lang.cce.vsqrt(te.lang.cce.vadds(variance, epsilon))
res = [y, batch_mean, batch_std, running_mean, running_std, mean_updated, variance_updated]
with tvm.target.cce():
sch = generic.auto_schedule(res)
config = {"name": kernel_name,
"tensor_list": [x_input, x_sum, x_square_sum, mean, variance] + res}
te.lang.cce.cce_build_code(sch, config)
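A NumPy mirror of the elementwise math in the kernel above, kept as an illustrative sketch (the helper name batchnorm_fold_np, the per-channel vector inputs, and the scalar num = N*H*W are assumptions for readability, not part of the op interface):

import numpy as np

def batchnorm_fold_np(x_sum, x_square_sum, mean, variance, num,
                      momentum=0.9, epsilon=1e-5):
    momentum = 1.0 - momentum                      # same flip as the kernel entry
    batch_mean = x_sum / num
    batch_var_biased = x_square_sum / num - batch_mean ** 2
    scaler = 0.0 if num == 1 else float(num) / (num - 1)
    batch_variance = batch_var_biased * scaler
    batch_std = np.sqrt(batch_variance + epsilon)
    factor, factor_reverse = 1.0 - momentum, momentum
    mean_updated = batch_mean * factor + mean * factor_reverse
    variance_updated = batch_variance * factor + variance * factor_reverse
    running_mean = mean
    running_std = np.sqrt(variance + epsilon)
    return batch_mean, batch_std, running_mean, running_std, mean_updated, variance_updated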

View File

@@ -0,0 +1,110 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""_BatchNormFold2 op"""
import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager
from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
SHAPE_SIZE_LIMIT = 2147483648
batchnorm_fold2_op_info = TBERegOp("BatchNormFold2_D") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("batchnorm_fold2.so") \
.compute_cost(10) \
.kernel_name("batchnorm_fold2") \
.partial_flag(True) \
.op_pattern("formatAgnostic") \
.input(0, "x", None, "required", None) \
.input(1, "beta", None, "required", None) \
.input(2, "gamma", None, "required", None) \
.input(3, "batch_std", None, "required", None) \
.input(4, "batch_mean", None, "required", None) \
.input(5, "running_std", None, "required", None) \
.output(0, "y", True, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
@op_info_register(batchnorm_fold2_op_info)
def _batchnorm_fold2_tbe():
"""_BatchNormFold2 TBE register"""
return
@fusion_manager.register("batchnorm_fold2")
def batchnorm_fold2_compute(x, beta, gamma, batch_std, batch_mean, running_std, kernel_name="batchnorm_fold2"):
"""_BatchNormFold2 compute"""
shape_x = te.lang.cce.util.shape_to_list(x.shape)
factor = te.lang.cce.vdiv(running_std, batch_std)
factor_b = te.lang.cce.broadcast(factor, shape_x)
res = te.lang.cce.vmul(x, factor_b)
bias = te.lang.cce.vdiv(batch_mean, batch_std)
bias = te.lang.cce.vmul(bias, gamma)
bias = te.lang.cce.vsub(beta, bias)
bias_b = te.lang.cce.broadcast(bias, shape_x)
res = te.lang.cce.vadd(res, bias_b)
return res
@util.check_input_type(dict, dict, dict, dict, dict, dict, dict, str)
def batchnorm_fold2(x, beta, gamma, batch_std, batch_mean, running_std, y, kernel_name="batchnorm_fold2"):
"""_BatchNormFold2 op"""
shape = x.get("shape")
util.check_kernel_name(kernel_name)
util.check_shape_rule(shape)
util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
check_list = ["float16", "float32"]
inp_dtype = x.get("dtype").lower()
if not inp_dtype in check_list:
raise RuntimeError("Dtype of input only support float16, float32")
data_format = x.get("format")
ori_format = x.get("ori_format")
if data_format.upper() not in ("NC1HWC0", "NCHW"):
raise RuntimeError("Un supported data format {}".format(data_format))
if data_format.upper() == "NCHW" and ori_format != "NCHW":
raise RuntimeError("data_format(NCHW) must same as ori_format")
shape_c = gamma.get("shape")
if gamma.get("format").upper() == "NCHW":
shape_c = 1, gamma.get("shape")[0], 1, 1
x_t = tvm.placeholder(shape, name="x", dtype=inp_dtype)
beta_t = tvm.placeholder(shape_c, name="beta", dtype=inp_dtype)
gamma_t = tvm.placeholder(shape_c, name="gamma", dtype=inp_dtype)
batch_std_t = tvm.placeholder(shape_c, name="batch_std", dtype=inp_dtype)
batch_mean_t = tvm.placeholder(shape_c, name="batch_mean", dtype=inp_dtype)
running_std_t = tvm.placeholder(shape_c, name="running_std", dtype=inp_dtype)
res = batchnorm_fold2_compute(x_t, beta_t, gamma_t, batch_std_t, batch_mean_t,
running_std_t, kernel_name)
with tvm.target.cce():
sch = generic.auto_schedule(res)
config = {"print_ir": False,
"name": kernel_name,
"tensor_list": [x_t, beta_t, gamma_t, batch_std_t, batch_mean_t, running_std_t, res]}
te.lang.cce.cce_build_code(sch, config)
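Equivalently, for an NCHW input the compute above reduces to the following NumPy sketch (illustrative only; batchnorm_fold2_np is not part of the op):

import numpy as np

def batchnorm_fold2_np(x, beta, gamma, batch_std, batch_mean, running_std):
    # per-channel rescale plus per-channel bias, broadcast over N, H, W
    factor = (running_std / batch_std)[None, :, None, None]
    bias = (beta - gamma * batch_mean / batch_std)[None, :, None, None]
    return x * factor + bias

x = np.random.randn(2, 3, 4, 4).astype(np.float32)
y = batchnorm_fold2_np(x, np.zeros(3), np.ones(3), np.ones(3), np.zeros(3), np.ones(3))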

View File

@@ -0,0 +1,126 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""_BatchNormFold2Grad op"""
import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager
from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
SHAPE_SIZE_LIMIT = 2147483648
batchnorm_fold2_grad_op_info = TBERegOp("BatchNormFold2GradD") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("batchnorm_fold2_grad.so") \
.compute_cost(10) \
.kernel_name("batchnorm_fold2_grad") \
.partial_flag(True) \
.op_pattern("formatAgnostic") \
.input(0, "dout", None, "required", None) \
.input(1, "dout_reduce", None, "required", None) \
.input(2, "dout_x_reduce", None, "required", None) \
.input(3, "gamma", None, "required", None) \
.input(4, "batch_std", None, "required", None) \
.input(5, "batch_mean", None, "required", None) \
.input(6, "running_std", None, "required", None) \
.output(0, "d_batch_std", True, "required", "all") \
.output(1, "d_batch_mean", True, "required", "all") \
.output(2, "d_gamma", True, "required", "all") \
.output(3, "dx", True, "required", "all") \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD) \
.get_op_info()
@op_info_register(batchnorm_fold2_grad_op_info)
def _batchnorm_fold2_grad_tbe():
"""_BatchNormFold2Grad TBE register"""
return
@fusion_manager.register("batchnorm_fold2_grad")
def batchnorm_fold2_grad_compute(dout, dout_reduce, dout_x_reduce, gamma, batch_std, batch_mean, running_std,
kernel_name="batchnorm_fold2_grad"):
"""_BatchNormFold2Grad"""
shape_x = te.lang.cce.util.shape_to_list(dout.shape)
d_batch_std_1 = te.lang.cce.vmul(dout_reduce, batch_mean)
d_batch_std_1 = te.lang.cce.vmul(d_batch_std_1, gamma)
d_batch_std_2 = te.lang.cce.vmul(dout_x_reduce, running_std)
d_batch_std = te.lang.cce.vsub(d_batch_std_1, d_batch_std_2)
d_batch_std = te.lang.cce.vdiv(d_batch_std, batch_std)
d_batch_std = te.lang.cce.vdiv(d_batch_std, batch_std)
d_batch_mean = te.lang.cce.vmul(dout_reduce, gamma)
d_batch_mean = te.lang.cce.vdiv(d_batch_mean, batch_std)
d_batch_mean = te.lang.cce.vmuls(d_batch_mean, -1.)
d_gamma = te.lang.cce.vmul(dout_reduce, batch_mean)
d_gamma = te.lang.cce.vdiv(d_gamma, batch_std)
d_gamma = te.lang.cce.vmuls(d_gamma, -1.)
dx = te.lang.cce.vdiv(running_std, batch_std)
dx = te.lang.cce.broadcast(dx, shape_x)
dx = te.lang.cce.vmul(dx, dout)
return [d_batch_std, d_batch_mean, d_gamma, dx]
@util.check_input_type(dict, dict, dict, dict, dict, dict, dict, dict, dict, dict, dict, str)
def batchnorm_fold2_grad(dout, dout_reduce, dout_x_reduce, gamma, batch_std, batch_mean, running_std, d_batch_std,
d_batch_mean, d_gamma, dx, kernel_name="batchnorm_fold2_grad"):
"""_BatchNormFold2Grad op """
shape = dout.get("shape")
util.check_kernel_name(kernel_name)
util.check_shape_rule(shape)
util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
check_list = ["float16", "float32"]
inp_dtype = dout.get("dtype").lower()
if not inp_dtype in check_list:
raise RuntimeError("Dtype of input only support float16, float32")
data_format = dout.get("format")
ori_format = dout.get("ori_format")
if data_format.upper() not in ("NC1HWC0", "NCHW"):
raise RuntimeError("Un supported data format {}".format(data_format))
if data_format.upper() == "NCHW" and ori_format != "NCHW":
raise RuntimeError("data_format(NCHW) must same as ori_format")
shape_c = gamma.get("shape")
if gamma.get("format").upper() == "NCHW":
shape_c = 1, gamma.get("shape")[0], 1, 1
dout_t = tvm.placeholder(shape, name="dout", dtype=inp_dtype)
dout_reduce_t = tvm.placeholder(shape_c, name="dout_reduce", dtype=inp_dtype)
dout_x_reduce_t = tvm.placeholder(shape_c, name="dout_x_reduce", dtype=inp_dtype)
gamma_t = tvm.placeholder(shape_c, name="gamma", dtype=inp_dtype)
batch_std_t = tvm.placeholder(shape_c, name="batch_std", dtype=inp_dtype)
batch_mean_t = tvm.placeholder(shape_c, name="batch_mean", dtype=inp_dtype)
running_std_t = tvm.placeholder(shape_c, name="running_std", dtype=inp_dtype)
res_list = batchnorm_fold2_grad_compute(dout_t, dout_reduce_t, dout_x_reduce_t, gamma_t, batch_std_t, batch_mean_t,
running_std_t, kernel_name)
with tvm.target.cce():
sch = generic.auto_schedule(res_list)
tensor_list = [dout_t, dout_reduce_t, dout_x_reduce_t, gamma_t, batch_std_t, batch_mean_t, running_std_t] + list(
res_list)
config = {"print_ir": False,
"name": kernel_name,
"tensor_list": tensor_list}
te.lang.cce.cce_build_code(sch, config)

View File

@@ -0,0 +1,107 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""_BatchNormFold2GradReduce op"""
import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager
from te.platform.cce_build import build_config
from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
SHAPE_SIZE_LIMIT = 2147483648
batchnorm_fold2_grad_reduce_op_info = TBERegOp("BatchNormFold2GradReduce") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("batchnorm_fold2_grad_reduce.so") \
.compute_cost(10) \
.kernel_name("batchnorm_fold2_grad_reduce") \
.partial_flag(True) \
.op_pattern("formatAgnostic") \
.input(0, "dout", None, "required", None) \
.input(1, "x", None, "required", None) \
.output(0, "dout_reduce", True, "required", "all") \
.output(1, "dout_x_reduce", True, "required", "all") \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
@op_info_register(batchnorm_fold2_grad_reduce_op_info)
def _batchnorm_fold2_grad_reduce_tbe():
"""_BatchNormFold2GradReduce TBE register"""
return
@fusion_manager.register("batchnorm_fold2_grad_reduce")
def batchnorm_fold2_grad_reduce_compute(dout, x, dout_args, kernel_name="batchnorm_fold2_grad_reduce"):
"""_BatchNormFold2GradReduce compute"""
dtype = dout_args.get("dtype")
dout_format = dout_args.get("format")
ori_format = dout_args.get("ori_format")
shape = dout_args.get("shape")
if dtype == "float16":
dout = te.lang.cce.cast_to(dout, "float32")
x = te.lang.cce.cast_to(x, "float32")
dout_x = te.lang.cce.vmul(dout, x)
if dout_format == "NC1HWC0":
axis = [0, 2, 3]
dout_reduce, dout_x_reduce = te.lang.cce.tuple_sum([dout, dout_x], axis, True)
else:
axis = list(range(len(shape)))
if ori_format == "NCHW":
axis.pop(1)
for _, i in enumerate(range(len(shape))):
if shape[i] == 1 and i in axis:
axis.remove(i)
dout_reduce = te.lang.cce.sum(dout, axis, False)
dout_x_reduce = te.lang.cce.sum(dout_x, axis, False)
return [dout_reduce, dout_x_reduce]
@util.check_input_type(dict, dict, dict, dict, str)
def batchnorm_fold2_grad_reduce(dout, x, dout_reduce, dout_x_reduce, kernel_name="batchnorm_fold2_grad_reduce"):
"""_BatchNormFold2GradReduce op"""
shape = x.get("shape")
x_format = x.get("format")
util.check_kernel_name(kernel_name)
util.check_shape_rule(shape)
util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
check_list = ["float16", "float32"]
inp_dtype = x.get("dtype").lower()
if not inp_dtype in check_list:
raise RuntimeError("Dtype of input only support float16, float32")
dout_t = tvm.placeholder(shape, name="dout", dtype=inp_dtype)
x_t = tvm.placeholder(shape, name="x", dtype=inp_dtype)
res_list = batchnorm_fold2_grad_reduce_compute(dout_t, x_t, dout, kernel_name)
if x_format == "NC1HWC0":
with tvm.target.cce():
sch = generic.auto_schedule(res_list)
tensor_list = [dout_t, x_t] + list(res_list)
config = {"print_ir": False,
"name": kernel_name,
"tensor_list": tensor_list}
te.lang.cce.cce_build_code(sch, config)
return
from impl.bn_training_reduce import bn_training_reduce_schedule_nd
sch, tensor_list = bn_training_reduce_schedule_nd(res_list)
with build_config:
tvm.build(sch, tensor_list, "cce", name=kernel_name)

View File

@@ -0,0 +1,124 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""_BatchNormFoldGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
import te.lang.cce
from te import tvm
from topi import generic
from topi.cce import util
batch_norm_op_info = TBERegOp("BatchNormFoldGradD") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("batchnorm_fold_grad.so") \
.compute_cost(10) \
.kernel_name("batchnorm_fold_grad") \
.partial_flag(True) \
.attr("epsilon", "optional", "float", "all") \
.attr("is_training", "optional", "bool", "all") \
.attr("freeze_bn", "optional", "int", "all") \
.input(0, "d_batch_mean", False, "required", "all") \
.input(1, "d_batch_std", False, "required", "all") \
.input(2, "x", False, "required", "all") \
.input(3, "batch_mean", False, "required", "all") \
.input(4, "batch_std", False, "required", "all") \
.output(0, "dx", False, "required", "all") \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F16_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
@op_info_register(batch_norm_op_info)
def _batchnorm_fold_grad_tbe():
"""_BatchNormFoldGrad TBE register"""
return
def _batchnorm_fold_grad_compute(d_batch_mean, d_batch_std, data_x, batch_mean, batch_std):
"""_batchnorm_fold_grad_compute """
shape_x = te.lang.cce.util.shape_to_list(data_x.shape)
normal_size = shape_x[0] * shape_x[2] * shape_x[3]
d_batch_mean_broad = te.lang.cce.broadcast(d_batch_mean, shape_x)
d_batch_std_broad = te.lang.cce.broadcast(d_batch_std, shape_x)
batch_mean_broad = te.lang.cce.broadcast(batch_mean, shape_x)
batch_std_broad = te.lang.cce.broadcast(batch_std, shape_x)
dx = te.lang.cce.vsub(data_x, batch_mean_broad)
dx = te.lang.cce.vmul(dx, d_batch_std_broad)
dx = te.lang.cce.vdiv(dx, batch_std_broad)
dx = te.lang.cce.vadd(dx, d_batch_mean_broad)
dx = te.lang.cce.vmuls(dx, tvm.const(1. / normal_size, dtype=dx.dtype))
return [dx]
@util.check_input_type(dict, dict, dict, dict, dict, dict,
float, bool, int, str)
def batchnorm_fold_grad(d_batch_mean, d_batch_std, x, batch_mean, batch_std, dx,
epsilon=1e-5, is_training=True, freeze_bn=0, kernel_name="batchnorm_fold_grad"):
"""batchnorm_fold_grad op """
util.check_kernel_name(kernel_name)
for iv in (d_batch_mean, d_batch_std, x, batch_mean, batch_std):
util.check_shape_rule(iv.get("shape"))
util.check_tensor_shape_size(iv.get("shape"))
check_tuple = ("float16", "float32")
for iv in (d_batch_mean, d_batch_std, x, batch_mean, batch_std):
util.check_dtype_rule(iv.get("dtype").lower(), check_tuple)
shape_x = x.get("shape")
dtype_x = x.get("dtype")
format_data = x.get("format").upper()
if format_data not in ("NCHW", "NC1HWC0"):
raise RuntimeError("Format of input only support 4D and 5HD")
shape_mean = d_batch_mean.get("shape")
dtype_mean = d_batch_mean.get("dtype").lower()
if format_data == "NC1HWC0":
if len(shape_x) != 5:
raise RuntimeError("batchnorm_fold only support shape 5D"
"when input format is NC1HWC0")
shape_mean = (1, shape_x[1], 1, 1, shape_x[4])
elif format_data == "NCHW":
if len(shape_x) < 2 or len(shape_x) > 4:
raise RuntimeError("batchnorm_fold only support shape 2D to 4D")
if shape_x[1] != shape_mean[0]:
raise RuntimeError("data_format is NCHW, shape_bias must"
"be equal to the second axis of shape_x")
shape_mean = (1, shape_x[1],)
for _ in range(2, len(shape_x)):
shape_mean = shape_mean + (1,)
d_batch_mean = tvm.placeholder(shape_mean, name="d_batch_mean", dtype=dtype_mean)
d_batch_std = tvm.placeholder(shape_mean, name="d_batch_std", dtype=dtype_mean)
data_x = tvm.placeholder(shape_x, name="data_x", dtype=dtype_x.lower())
batch_mean = tvm.placeholder(shape_mean, name="batch_mean", dtype=dtype_mean)
batch_std = tvm.placeholder(shape_mean, name="batch_std", dtype=dtype_mean)
res = _batchnorm_fold_grad_compute(d_batch_mean, d_batch_std, data_x, batch_mean, batch_std)
with tvm.target.cce():
sch = generic.auto_schedule(res)
tensor_list = [d_batch_mean, d_batch_std, data_x, batch_mean, batch_std] + res
config = {"name": kernel_name,
"tensor_list": tensor_list}
te.lang.cce.cce_build_code(sch, config)

View File

@@ -0,0 +1,92 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CorrectionMul op"""
import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager
from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
SHAPE_SIZE_LIMIT = 2147483648
correction_mul_op_info = TBERegOp("CorrectionMul") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("correction_mul.so") \
.compute_cost(10) \
.kernel_name("correction_mul") \
.partial_flag(True) \
.op_pattern("formatAgnostic") \
.attr("channel_axis", "optional", "int", "all") \
.input(0, "x", None, "required", None) \
.input(1, "batch_std", None, "required", None) \
.input(2, "running_std", None, "required", None) \
.output(0, "y", True, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
@op_info_register(correction_mul_op_info)
def _correction_mul_tbe():
"""CorrectionMul TBE register"""
return
@fusion_manager.register("correction_mul")
def correction_mul_compute(x, batch_std, running_std, kernel_name="correction_mul"):
"""CorrectionMul compute"""
shape_x = te.lang.cce.util.shape_to_list(x.shape)
factor = te.lang.cce.vdiv(batch_std, running_std)
factor_b = te.lang.cce.broadcast(factor, shape_x)
res = te.lang.cce.vmul(x, factor_b)
return res
@util.check_input_type(dict, dict, dict, dict, int, str)
def correction_mul(x, batch_std, running_std, y, channel, kernel_name="correction_mul"):
"""CorrectionMul op"""
shape = x.get("shape")
data_format = x.get("format")
util.check_kernel_name(kernel_name)
util.check_shape_rule(shape)
util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
check_list = ["float16", "float32"]
inp_dtype = x.get("dtype").lower()
if not inp_dtype in check_list:
raise RuntimeError("Dtype of input only support float16, float32")
# shape = util.shape_refine(shape)
x_t = tvm.placeholder(shape, name="x", dtype=inp_dtype)
shape_c = [1] * len(shape)
shape_c[channel] = batch_std.get("ori_shape")[0]
if data_format == "NC1HWC0" and channel == 1:
shape_c = batch_std.get("shape")
batch_std_t = tvm.placeholder(shape_c, name="batch_std", dtype=inp_dtype)
running_std_t = tvm.placeholder(shape_c, name="running_std", dtype=inp_dtype)
res = correction_mul_compute(x_t, batch_std_t, running_std_t, kernel_name)
with tvm.target.cce():
sch = generic.auto_schedule(res)
config = {"print_ir": False,
"name": kernel_name,
"tensor_list": [x_t, batch_std_t, running_std_t, res]}
te.lang.cce.cce_build_code(sch, config)
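For reference, the element-wise computation registered by this kernel can be sketched in plain NumPy; the helper name correction_mul_reference and the sample shapes below are illustrative only and are not part of the op definition:

import numpy as np

def correction_mul_reference(x, batch_std, running_std, channel_axis=0):
    """Scale x by batch_std / running_std, broadcast along channel_axis."""
    shape_c = [1] * x.ndim
    shape_c[channel_axis] = batch_std.shape[0]
    factor = (batch_std / running_std).reshape(shape_c)
    return x * factor

# Example: per-output-channel correction of convolution weights (channel_axis=0).
weight = np.random.randn(4, 3, 3, 3).astype(np.float32)
batch_std = np.random.rand(4).astype(np.float32) + 0.5
running_std = np.random.rand(4).astype(np.float32) + 0.5
corrected = correction_mul_reference(weight, batch_std, running_std)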

View File

@ -0,0 +1,134 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CorrectionMul op"""
import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager
from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
SHAPE_SIZE_LIMIT = 2147483648
correction_mul_grad_op_info = TBERegOp("CorrectionMulGrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("correction_mul_grad.so") \
.compute_cost(10) \
.kernel_name("correction_mul_grad") \
.partial_flag(True) \
.op_pattern("formatAgnostic") \
.attr("channel_axis", "optional", "int", "all") \
.input(0, "dout", None, "required", None) \
.input(1, "x", None, "required", None) \
.input(2, "batch_std", None, "required", None) \
.input(3, "running_std", None, "required", None) \
.output(0, "dx", True, "required", "all") \
.output(1, "d_batch_std", True, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
@op_info_register(correction_mul_grad_op_info)
def _correction_mul_grad_tbe():
"""CorrectionMulGrad TBE register"""
return
@fusion_manager.register("correction_mul_grad")
def correction_mul_grad_compute(dout, x, batch_std, running_std, channel, data_format, kernel_name="correction_mul"):
"""CorrectionMulGrad compute"""
shape_x = te.lang.cce.util.shape_to_list(x.shape)
factor = te.lang.cce.vdiv(batch_std, running_std)
factor_b = te.lang.cce.broadcast(factor, shape_x)
dx = te.lang.cce.vmul(dout, factor_b)
mul_data = te.lang.cce.vmul(dout, x)
if channel == 0:
if data_format == "NCHW":
axis = [1, 2, 3]
else:
axis = [1, 2, 3, 4]
else:
axis = [2, 3]
red_data = te.lang.cce.sum(mul_data, axis, keepdims=True)
d_batch_std = te.lang.cce.vdiv(red_data, running_std)
return [dx, d_batch_std]
@util.check_input_type(dict, dict, dict, dict, dict, dict, int, str)
def correction_mul_grad(dout, x, batch_std, running_std, dx, d_batch_std, channel, kernel_name="correction_mul_grad"):
"""CorrectionMulGrad op"""
shape_dout = dout.get("shape")
shape_x = x.get("shape")
dtype_dout = dout.get("dtype")
dtype_x = x.get("dtype")
dtype_batch_std = batch_std.get("dtype")
dtype_running_std = running_std.get("dtype")
inp_dtype_dout = dtype_dout.lower()
inp_dtype_x = dtype_x.lower()
inp_dtype_batch_std = dtype_batch_std.lower()
inp_dtype_running_std = dtype_running_std.lower()
util.check_dtype_rule(inp_dtype_dout, ("float16", "float32"))
util.check_dtype_rule(inp_dtype_x, ("float16", "float32"))
util.check_dtype_rule(inp_dtype_batch_std, ("float32",))
util.check_dtype_rule(inp_dtype_running_std, ("float32",))
util.compare_tensor_dict_key(dout, x, "dtype")
util.compare_tensor_dict_key(dout, x, "shape")
util.compare_tensor_dict_key(dx, x, "shape")
util.compare_tensor_dict_key(batch_std, running_std, "shape")
util.compare_tensor_dict_key(batch_std, d_batch_std, "shape")
util.check_kernel_name(kernel_name)
util.check_shape_rule(shape_x)
util.check_shape_size(shape_x, SHAPE_SIZE_LIMIT)
data_format = dout.get("format")
ori_format = dout.get("ori_format")
if data_format.upper() not in ("NC1HWC0", "NCHW"):
raise RuntimeError("Un supported data format {}".format(data_format))
if data_format.upper() == "NCHW" and ori_format != "NCHW":
raise RuntimeError("data_format(NCHW) must same as ori_format")
shape_c = [1] * len(shape_x)
shape_c[channel] = batch_std.get("ori_shape")[0]
if data_format == "NC1HWC0" and channel == 1:
shape_c = batch_std.get("shape")
dout_t = tvm.placeholder(shape_dout, name="dout", dtype=inp_dtype_dout)
x_t = tvm.placeholder(shape_x, name="x", dtype=inp_dtype_x)
batch_std_t = tvm.placeholder(shape_c, name="batch_std", dtype=inp_dtype_batch_std)
running_std_t = tvm.placeholder(shape_c, name="running_std", dtype=inp_dtype_running_std)
res_list = correction_mul_grad_compute(dout_t, x_t, batch_std_t, running_std_t, channel, data_format, kernel_name)
with tvm.target.cce():
sch = generic.auto_schedule(res_list)
tensor_list = [dout_t, x_t, batch_std_t, running_std_t] + list(res_list)
config = {"print_ir": False,
"name": kernel_name,
"tensor_list": tensor_list}
te.lang.cce.cce_build_code(sch, config)
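A minimal NumPy sketch of the gradient this kernel computes, ignoring the NC1HWC0 layout handling above (the helper name correction_mul_grad_reference is illustrative only):

import numpy as np

def correction_mul_grad_reference(dout, x, batch_std, running_std, channel_axis=0):
    """dx mirrors the forward scaling; d_batch_std reduces dout * x over the non-channel axes."""
    shape_c = [1] * x.ndim
    shape_c[channel_axis] = batch_std.shape[0]
    factor = (batch_std / running_std).reshape(shape_c)
    dx = dout * factor
    reduce_axes = tuple(i for i in range(x.ndim) if i != channel_axis)
    d_batch_std = np.sum(dout * x, axis=reduce_axes) / running_std
    return dx, d_batch_std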

View File

@ -0,0 +1,146 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FakeQuantWithMinMax op"""
from functools import reduce as functools_reduce
import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager
from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
fake_quant_op_info = TBERegOp("FakeQuantWithMinMax") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("fake_quant_with_min_max_vars_ema.so") \
.compute_cost(10) \
.kernel_name("fake_quant_with_min_max_vars_ema") \
.partial_flag(True) \
.attr("ema", "optional", "bool", "all") \
.attr("ema_decay", "optional", "float", "all") \
.attr("symmetric", "optional", "bool", "all") \
.attr("narrow_range", "optional", "bool", "all") \
.attr("training", "optional", "bool", "all") \
.attr("num_bits", "optional", "int", "all") \
.attr("quant_delay", "optional", "int", "all") \
.input(0, "x", None, "required", None) \
.input(1, "min", None, "required", None) \
.input(2, "max", None, "required", None) \
.output(0, "y", True, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
@op_info_register(fake_quant_op_info)
def _fake_quant_tbe():
"""FakeQuantWithMinMax TBE register"""
return
@fusion_manager.register("fake_quant_with_min_max_vars_ema")
def fake_quant_with_min_max_vars_ema_compute(x, min_val, max_val, y, quant_min, quant_max,
kernel_name="correction_mul"):
"""FakeQuantWithMinMax"""
shape = te.lang.cce.util.shape_to_list(x.shape)
shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
quant_min = te.lang.cce.broadcast(quant_min, shape_min, x.dtype)
quant_max = te.lang.cce.broadcast(quant_max, shape_min, x.dtype)
min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype)
max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype)
# CalNudge(NudgeMinMax)
scale = te.lang.cce.vdiv(te.lang.cce.vsub(max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale))
# Nudge zero point
nudge_zp = te.lang.cce.round(te.lang.cce.vmin(quant_max, te.lang.cce.vmax(quant_min, zp_from_min)))
nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale)
nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale)
# broadcast nudged min/max and scale to the input shape
nudge_min = te.lang.cce.broadcast(nudge_min, shape, x.dtype)
nudge_max = te.lang.cce.broadcast(nudge_max, shape, x.dtype)
scale = te.lang.cce.broadcast(scale, shape, x.dtype)
# FakeQuant
input_x = te.lang.cce.vmin(nudge_max, te.lang.cce.vmax(nudge_min, x))
nudge_input = te.lang.cce.floor(te.lang.cce.vadds(te.lang.cce.vdiv(te.lang.cce.vsub(input_x, nudge_min), scale),
0.5))
res = te.lang.cce.vadd(te.lang.cce.vmul(nudge_input, scale), nudge_min)
return res
@util.check_input_type(dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str)
def fake_quant_with_min_max_vars_ema(x, min_val, max_val, y,
ema, ema_decay, symmetric, narrow_range, training, num_bits, quant_delay,
kernel_name="fake_quant"):
"""FakeQuantWithMinMax"""
input_shape = x.get("shape")
input_dtype = x.get("dtype")
min_shape = min_val.get("ori_shape")
min_dtype = min_val.get("dtype")
max_shape = max_val.get("ori_shape")
max_dtype = max_val.get("dtype")
min_shape = util.scalar2tensor_one(min_shape)
max_shape = util.scalar2tensor_one(max_shape)
util.check_kernel_name(kernel_name)
util.check_shape_rule(input_shape)
util.check_shape_rule(min_shape, 1, 1, 1)
util.check_shape_rule(max_shape, 1, 1, 1)
util.check_tensor_shape_size(input_shape)
util.check_tensor_shape_size(min_shape)
util.check_tensor_shape_size(max_shape)
check_list = ["float32", "float16"]
x_dtype = input_dtype.lower()
min_dtype = min_dtype.lower()
max_dtype = max_dtype.lower()
util.check_dtype_rule(x_dtype, check_list)
util.check_dtype_rule(min_dtype, check_list)
util.check_dtype_rule(max_dtype, check_list)
input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)
shape_min, _, _ = util.produce_shapes(min_shape, input_shape)
if symmetric:
quant_min = 0 - 2 ** (num_bits - 1)
quant_max = 2 ** (num_bits - 1) - 1
else:
quant_min = 0
quant_max = 2 ** num_bits - 1
if narrow_range:
quant_min = quant_min + 1
input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
res = fake_quant_with_min_max_vars_ema_compute(input_data, min_data, max_data, y,
quant_min, quant_max, kernel_name)
with tvm.target.cce():
sch = generic.auto_schedule(res)
tensor_list = [input_data, min_data, max_data, res]
config = {"print_ir": False,
"name": kernel_name,
"tensor_list": tensor_list}
te.lang.cce.cce_build_code(sch, config)
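The nudge-and-quantize math above can be checked against a small NumPy sketch (fake_quant_reference is an illustrative name; per-channel broadcasting is omitted and rounding-mode details may differ from the hardware kernel):

import numpy as np

def fake_quant_reference(x, min_val, max_val, num_bits=8, symmetric=False, narrow_range=False):
    """Nudge [min_val, max_val] onto the integer grid, then quantize-dequantize x."""
    if symmetric:
        quant_min, quant_max = -2 ** (num_bits - 1), 2 ** (num_bits - 1) - 1
    else:
        quant_min, quant_max = 0, 2 ** num_bits - 1
    if narrow_range:
        quant_min += 1
    scale = (max_val - min_val) / (quant_max - quant_min)
    zp_from_min = quant_min - min_val / scale
    nudge_zp = np.round(np.clip(zp_from_min, quant_min, quant_max))
    nudge_min = (quant_min - nudge_zp) * scale
    nudge_max = (quant_max - nudge_zp) * scale
    clipped = np.clip(x, nudge_min, nudge_max)
    return np.floor((clipped - nudge_min) / scale + 0.5) * scale + nudge_min

x = np.linspace(-8.0, 8.0, 9).astype(np.float32)
print(fake_quant_reference(x, min_val=-6.0, max_val=6.0))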

View File

@ -0,0 +1,156 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FakeQuantWithMinMaxGrad op"""
from functools import reduce as functools_reduce
import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager
from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
SHAPE_SIZE_LIMIT = 2147483648
D_TYPE = 'float32'
fake_quant_grad_op_info = TBERegOp("FakeQuantWithMinMaxGrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("fake_quant_with_min_max_grad.so") \
.compute_cost(10) \
.kernel_name("fake_quant_with_min_max_grad") \
.partial_flag(True) \
.attr("num_bits", "optional", "int", "all") \
.attr("quant_delay", "optional", "int", "all") \
.input(0, "dout", None, "required", None) \
.input(1, "x", None, "required", None) \
.input(2, "min", None, "required", None) \
.input(3, "max", None, "required", None) \
.output(0, "dx", True, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
def _less_compare_float32(data_x, data_y):
"""_less_compare_float32 compute"""
shape_inputs = te.lang.cce.util.shape_to_list(data_x.shape)
min_value = tvm.const(2 ** (-126), dtype=D_TYPE)
max_value = tvm.const(2 ** 62, dtype=D_TYPE)
factor_value = tvm.const(2 ** 2, dtype=D_TYPE)
data_zero = te.lang.cce.broadcast(tvm.const(0, dtype=D_TYPE), shape_inputs, D_TYPE)
min_value_tensor = te.lang.cce.vadds(data_zero, min_value)
res_sub = te.lang.cce.vsub(data_y, data_x)
res_min = te.lang.cce.vmin(res_sub, min_value_tensor)
res_max = te.lang.cce.vmax(res_min, data_zero)
res_max_mul = te.lang.cce.vmuls(res_max, max_value)
res_max_mul_max = te.lang.cce.vmuls(res_max_mul, max_value)
res = te.lang.cce.vmuls(res_max_mul_max, factor_value)
return res
@op_info_register(fake_quant_grad_op_info)
def _fake_quant_grad_tbe():
"""FakeQuantWithMinMaxGrad TBE register"""
return
@fusion_manager.register("fake_quant_with_min_max_grad")
def fake_quant_with_min_max_grad_compute(dout, x, min_val, max_val, quant_min, quant_max,
kernel_name="fake_quant_with_min_max_grad"):
"""FakeQuantWithMinMaxGrad"""
shape = te.lang.cce.util.shape_to_list(x.shape)
shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
quant_min = tvm.const(quant_min, x.dtype)
quant_max = tvm.const(quant_max, x.dtype)
quant_min = te.lang.cce.broadcast(quant_min, shape_min)
quant_max = te.lang.cce.broadcast(quant_max, shape_min)
# CalNudge(NudgeMinMax)
scale = te.lang.cce.vdiv(te.lang.cce.vsub(max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale))
# Nudge zero point
nudge_zp = te.lang.cce.round(te.lang.cce.vmin(quant_max, te.lang.cce.vmax(quant_min, zp_from_min)))
nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale)
nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale)
nudge_min = te.lang.cce.broadcast(nudge_min, shape)
nudge_max = te.lang.cce.broadcast(nudge_max, shape)
bool_over_min = _less_compare_float32(nudge_min, x)
bool_less_max = _less_compare_float32(x, nudge_max)
bool_between = te.lang.cce.vmul(bool_over_min, bool_less_max)
res = te.lang.cce.vmul(dout, bool_between)
return res
@util.check_input_type(dict, dict, dict, dict, dict, int, int, str)
def fake_quant_with_min_max_grad(dout, x, min_val, max_val, dx, num_bits, quant_delay,
kernel_name="fake_quant_with_min_max_grad"):
"""FakeQuantWithMinMaxGrad"""
input_shape = x.get("shape")
input_dtype = x.get("dtype")
min_shape = min_val.get("ori_shape")
min_dtype = min_val.get("dtype")
max_shape = max_val.get("ori_shape")
max_dtype = max_val.get("dtype")
min_shape = util.scalar2tensor_one(min_shape)
max_shape = util.scalar2tensor_one(max_shape)
util.check_kernel_name(kernel_name)
util.check_shape_rule(input_shape)
util.check_shape_rule(min_shape, 1, 1, 1)
util.check_shape_rule(max_shape, 1, 1, 1)
util.check_tensor_shape_size(input_shape)
util.check_tensor_shape_size(min_shape)
util.check_tensor_shape_size(max_shape)
check_list = ["float32", 'float16']
x_dtype = input_dtype.lower()
min_dtype = min_dtype.lower()
max_dtype = max_dtype.lower()
util.check_dtype_rule(x_dtype, check_list)
util.check_dtype_rule(min_dtype, check_list)
util.check_dtype_rule(max_dtype, check_list)
input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)
shape_min, _, _ = util.produce_shapes(min_shape, input_shape)
quant_min = 0
quant_max = 2 ** num_bits - 1
dout_data = tvm.placeholder(input_shape, name="dout", dtype=x_dtype)
input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
res = fake_quant_with_min_max_grad_compute(dout_data, input_data, min_data, max_data, quant_min,
quant_max, kernel_name)
with tvm.target.cce():
sch = generic.auto_schedule(res)
tensor_list = [dout_data, input_data, min_data, max_data, res]
config = {"print_ir": False,
"name": kernel_name,
"tensor_list": tensor_list}
te.lang.cce.cce_build_code(sch, config)
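The gradient kernel above implements a straight-through estimator: the incoming gradient is passed only where x lies strictly inside the nudged range. A NumPy sketch (fake_quant_grad_reference is an illustrative name):

import numpy as np

def fake_quant_grad_reference(dout, x, min_val, max_val, num_bits=8):
    """Straight-through estimator: zero the gradient outside the nudged [min, max] range."""
    quant_min, quant_max = 0, 2 ** num_bits - 1
    scale = (max_val - min_val) / (quant_max - quant_min)
    nudge_zp = np.round(np.clip(quant_min - min_val / scale, quant_min, quant_max))
    nudge_min = (quant_min - nudge_zp) * scale
    nudge_max = (quant_max - nudge_zp) * scale
    inside = (x > nudge_min) & (x < nudge_max)
    return dout * inside.astype(dout.dtype)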

View File

@ -0,0 +1,137 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FakeQuantWithMinMaxUpdate op"""
from functools import reduce as functools_reduce
import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager
from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
fake_quant_update5d_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("fake_quant_with_min_max_update5d.so") \
.compute_cost(10) \
.kernel_name("fake_quant_with_min_max_update") \
.partial_flag(True) \
.attr("ema", "optional", "bool", "all") \
.attr("ema_decay", "optional", "float", "all") \
.attr("symmetric", "optional", "bool", "all") \
.attr("narrow_range", "optional", "bool", "all") \
.attr("training", "optional", "bool", "all") \
.attr("num_bits", "optional", "int", "all") \
.attr("quant_delay", "optional", "int", "all") \
.input(0, "x", None, "required", None) \
.input(1, "min", None, "required", None) \
.input(2, "max", None, "required", None) \
.output(0, "min_up", True, "required", "all") \
.output(1, "max_up", True, "required", "all") \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD) \
.get_op_info()
@op_info_register(fake_quant_update5d_op_info)
def _fake_quant_update5d_tbe():
"""_FakeQuantWithMinMaxUpdate5D TBE register"""
return
@fusion_manager.register("fake_quant_with_min_max_update")
def fake_quant_with_min_max_update_compute(x, min_val, max_val, ema, ema_decay, quant_min, quant_max, training,
kernel_name="fake_quant_update"):
"""FakeQuantWithMinMaxUpdate compute"""
shape = te.lang.cce.util.shape_to_list(x.shape)
shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype)
max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype)
if not ema:
ema_decay = 0.0
if training:
# CalMinMax
axis = tuple(range(len(shape)))
x_min = te.lang.cce.reduce_min(x, axis=axis)
x_max = te.lang.cce.reduce_max(x, axis=axis)
x_min = te.lang.cce.broadcast(x_min, shape_min)
x_max = te.lang.cce.broadcast(x_max, shape_min)
min_val = te.lang.cce.vadd(te.lang.cce.vmuls(min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay)))
max_val = te.lang.cce.vadd(te.lang.cce.vmuls(max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay)))
min_val = te.lang.cce.vmins(min_val, 0)
max_val = te.lang.cce.vmaxs(max_val, 0)
return [min_val, max_val]
@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str)
def fake_quant_with_min_max_update(x, min_val, max_val, min_up, max_up,
ema, ema_decay, symmetric, narrow_range, training, num_bits, quant_delay,
kernel_name="fake_quant_update"):
"""FakeQuantWithMinMax op"""
input_shape = x.get("shape")
input_dtype = x.get("dtype")
min_shape = min_val.get("ori_shape")
min_dtype = min_val.get("dtype")
max_shape = max_val.get("ori_shape")
max_dtype = max_val.get("dtype")
min_shape = util.scalar2tensor_one(min_shape)
max_shape = util.scalar2tensor_one(max_shape)
util.check_kernel_name(kernel_name)
util.check_shape_rule(input_shape)
util.check_shape_rule(min_shape, 1, 1, 1)
util.check_shape_rule(max_shape, 1, 1, 1)
util.check_tensor_shape_size(input_shape)
util.check_tensor_shape_size(min_shape)
util.check_tensor_shape_size(max_shape)
check_list = ["float32", "float16"]
x_dtype = input_dtype.lower()
min_dtype = min_dtype.lower()
max_dtype = max_dtype.lower()
util.check_dtype_rule(x_dtype, check_list)
util.check_dtype_rule(min_dtype, check_list)
util.check_dtype_rule(max_dtype, check_list)
input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)
shape_min, _, _ = util.produce_shapes(min_shape, input_shape)
if symmetric:
quant_min = 0 - 2 ** (num_bits - 1)
quant_max = 2 ** (num_bits - 1) - 1
else:
quant_min = 0
quant_max = 2 ** num_bits - 1
if narrow_range:
quant_min = quant_min + 1
input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
res_list = fake_quant_with_min_max_update_compute(input_data, min_data, max_data,
ema, ema_decay, quant_min, quant_max, training, kernel_name)
with tvm.target.cce():
sch = generic.auto_schedule(res_list)
tensor_list = [input_data, min_data, max_data] + list(res_list)
config = {"print_ir": False,
"name": kernel_name,
"tensor_list": tensor_list}
te.lang.cce.cce_build_code(sch, config)
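In NumPy terms, the update above amounts to an exponential moving average of the observed range, clamped so that the range always covers zero (min_max_update_reference is an illustrative name, not part of the kernel):

import numpy as np

def min_max_update_reference(x, min_val, max_val, ema=True, ema_decay=0.999, training=True):
    """EMA update of the observed min/max; the returned range always contains zero."""
    decay = ema_decay if ema else 0.0
    if training:
        min_val = min_val * decay + np.min(x) * (1 - decay)
        max_val = max_val * decay + np.max(x) * (1 - decay)
    return min(min_val, 0.0), max(max_val, 0.0)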

View File

@ -30,6 +30,10 @@ __all__ = ["FakeQuantWithMinMax",
"CorrectionMulGrad", "CorrectionMulGrad",
"BatchNormFold2", "BatchNormFold2",
"BatchNormFold2Grad", "BatchNormFold2Grad",
"BatchNormFoldD",
"BNTrainingReduce",
"BatchNormFold2_D",
"FakeQuantWithMinMaxUpdate",
] ]
@ -166,7 +170,7 @@ class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer):
>>> result = fake_quant(input_x, _min, _max) >>> result = fake_quant(input_x, _min, _max)
""" """
support_quant_bit = [4, 8] support_quant_bit = [4, 8]
channel_idx = 0 channel_axis = 0
@prim_attr_register @prim_attr_register
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, quant_delay=0, symmetric=False, narrow_range=False, def __init__(self, num_bits=8, ema=False, ema_decay=0.999, quant_delay=0, symmetric=False, narrow_range=False,
@ -188,8 +192,8 @@ class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer):
def infer_shape(self, x_shape, min_shape, max_shape): def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name) validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
validator.check_integer("min shape[0]", min_shape[0], x_shape[self.channel_idx], Rel.EQ, self.name) validator.check_integer("min shape[0]", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
validator.check_integer("max shape[0]", max_shape[0], x_shape[self.channel_idx], Rel.EQ, self.name) validator.check_integer("max shape[0]", max_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
return x_shape return x_shape
def infer_dtype(self, x_type, min_type, max_type): def infer_dtype(self, x_type, min_type, max_type):
@ -272,7 +276,7 @@ class BatchNormFold(PrimitiveWithInfer):
>>> global_step = Tensor(np.arange(6), mindspore.int32) >>> global_step = Tensor(np.arange(6), mindspore.int32)
>>> batch_mean, batch_std, running_mean, running_std = batch_norm_fold(input_x, mean, variance, global_step) >>> batch_mean, batch_std, running_mean, running_std = batch_norm_fold(input_x, mean, variance, global_step)
""" """
channel = 1 channel_axis = 1
@prim_attr_register @prim_attr_register
def __init__(self, momentum=0.1, epsilon=1e-5, is_training=True, freeze_bn=0): def __init__(self, momentum=0.1, epsilon=1e-5, is_training=True, freeze_bn=0):
@ -287,7 +291,7 @@ class BatchNormFold(PrimitiveWithInfer):
def infer_shape(self, x_shape, mean_shape, variance_shape, global_step_shape): def infer_shape(self, x_shape, mean_shape, variance_shape, global_step_shape):
validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, Rel.EQ, self.name) validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, Rel.EQ, self.name)
validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[self.channel], Rel.EQ, self.name) validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[self.channel_axis], Rel.EQ, self.name)
validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name) validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name)
return mean_shape, mean_shape, mean_shape, mean_shape return mean_shape, mean_shape, mean_shape, mean_shape
@ -314,7 +318,7 @@ class BatchNormFoldGrad(PrimitiveWithInfer):
>>> global_step = Tensor([2], mindspore.int32) >>> global_step = Tensor([2], mindspore.int32)
>>> result = batch_norm_fold_grad(d_batch_mean, d_batch_std, input_x, batch_mean, batch_std, global_step) >>> result = batch_norm_fold_grad(d_batch_mean, d_batch_std, input_x, batch_mean, batch_std, global_step)
""" """
channel = 1 channel_axis = 1
@prim_attr_register @prim_attr_register
def __init__(self, epsilon=1e-5, is_training=True, freeze_bn=0): def __init__(self, epsilon=1e-5, is_training=True, freeze_bn=0):
@ -333,8 +337,8 @@ class BatchNormFoldGrad(PrimitiveWithInfer):
"batch_mean shape", batch_mean_shape, Rel.EQ, self.name) "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
validator.check("d_batch_mean shape", d_batch_mean_shape, validator.check("d_batch_mean shape", d_batch_mean_shape,
"batch_std shape", batch_std_shape, Rel.EQ, self.name) "batch_std shape", batch_std_shape, Rel.EQ, self.name)
validator.check("d_batch_mean_shape[0]", d_batch_mean_shape[0], "input channel", x_shape[self.channel], Rel.EQ, validator.check("d_batch_mean_shape[0]", d_batch_mean_shape[0],
self.name) "input channel", x_shape[self.channel_axis], Rel.EQ, self.name)
validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name) validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name)
return x_shape return x_shape
@ -368,17 +372,17 @@ class CorrectionMul(PrimitiveWithInfer):
>>> running_std = Tensor(np.array([2, 1.2, 0.5]), mindspore.float32) >>> running_std = Tensor(np.array([2, 1.2, 0.5]), mindspore.float32)
>>> out = correction_mul(input_x, batch_std, running_std) >>> out = correction_mul(input_x, batch_std, running_std)
""" """
channel = 0
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self, channel_axis=0):
"""init correction mul layer""" """init correction mul layer"""
self.channel_axis = channel_axis
self.init_prim_io_names(inputs=['x', 'batch_std', 'running_std'], self.init_prim_io_names(inputs=['x', 'batch_std', 'running_std'],
outputs=['out']) outputs=['out'])
def infer_shape(self, x_shape, batch_std_shape, running_std_shape): def infer_shape(self, x_shape, batch_std_shape, running_std_shape):
validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name) validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel], validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
Rel.EQ, self.name) Rel.EQ, self.name)
return x_shape return x_shape
@ -400,20 +404,20 @@ class CorrectionMulGrad(PrimitiveWithInfer):
>>> running_std = Tensor(np.array([1.2, 0.1, 0.7, 2.3]).reshape(2, 1, 2), mindspore.float32) >>> running_std = Tensor(np.array([1.2, 0.1, 0.7, 2.3]).reshape(2, 1, 2), mindspore.float32)
>>> result = correction_mul_grad(dout, input_x, gamma, running_std) >>> result = correction_mul_grad(dout, input_x, gamma, running_std)
""" """
channel = 0
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self, channel_axis=0):
"""init correction mul layer""" """init correction mul layer"""
self.channel_axis = channel_axis
self.init_prim_io_names(inputs=['dout', 'x', 'gamma', 'running_std'], self.init_prim_io_names(inputs=['dout', 'x', 'gamma', 'running_std'],
outputs=['dx', 'd_gamma']) outputs=['dx', 'd_gamma'])
def infer_shape(self, dout_shape, x_shape, gamma_shape, running_std_shape): def infer_shape(self, dout_shape, x_shape, gamma_shape, running_std_shape):
validator.check("dout shape", dout_shape, "x_shape x", x_shape, Rel.EQ, self.name) validator.check("dout shape", dout_shape, "x_shape x", x_shape, Rel.EQ, self.name)
validator.check("gamma_shape[0]", gamma_shape[0], "dout channel size", dout_shape[self.channel], validator.check("gamma_shape[0]", gamma_shape[0], "dout channel size", dout_shape[self.channel_axis],
Rel.EQ, self.name)
validator.check("running_std_shape[0]", running_std_shape[0], "dout channel size", dout_shape[self.channel],
Rel.EQ, self.name) Rel.EQ, self.name)
validator.check("running_std_shape[0]", running_std_shape[0],
"dout channel size", dout_shape[self.channel_axis], Rel.EQ, self.name)
return x_shape, gamma_shape return x_shape, gamma_shape
def infer_dtype(self, dout_type, x_type, gamma_type, running_std_type): def infer_dtype(self, dout_type, x_type, gamma_type, running_std_type):
@ -454,7 +458,7 @@ class BatchNormFold2(PrimitiveWithInfer):
>>> result = batch_norm_fold2(input_x, beta, gamma, batch_std, batch_mean, >>> result = batch_norm_fold2(input_x, beta, gamma, batch_std, batch_mean,
>>> running_std, running_mean, global_step) >>> running_std, running_mean, global_step)
""" """
channel = 1 channel_axis = 1
@prim_attr_register @prim_attr_register
def __init__(self, freeze_bn=0): def __init__(self, freeze_bn=0):
@ -471,7 +475,7 @@ class BatchNormFold2(PrimitiveWithInfer):
validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, Rel.EQ, self.name) validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, Rel.EQ, self.name)
validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape, Rel.EQ, self.name) validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape, Rel.EQ, self.name)
validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, Rel.EQ, self.name) validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, Rel.EQ, self.name)
validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel], validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
Rel.EQ, self.name) Rel.EQ, self.name)
validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name) validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name)
return x_shape return x_shape
@ -501,7 +505,7 @@ class BatchNormFold2Grad(PrimitiveWithInfer):
>>> global_step = Tensor(np.array([-2]), mindspore.int32) >>> global_step = Tensor(np.array([-2]), mindspore.int32)
>>> result = bnf2_grad(dout, input_x, gamma, batch_std, batch_mean, running_std, running_mean, global_step) >>> result = bnf2_grad(dout, input_x, gamma, batch_std, batch_mean, running_std, running_mean, global_step)
""" """
channel = 1 channel_axis = 1
@prim_attr_register @prim_attr_register
def __init__(self, freeze_bn=0): def __init__(self, freeze_bn=0):
@ -519,7 +523,7 @@ class BatchNormFold2Grad(PrimitiveWithInfer):
validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name) validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape, Rel.EQ, self.name) validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape, Rel.EQ, self.name)
validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name) validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name)
validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel], validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel_axis],
Rel.EQ, self.name) Rel.EQ, self.name)
validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name) validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name)
return gamma_shape, gamma_shape, gamma_shape, gamma_shape, x_shape return gamma_shape, gamma_shape, gamma_shape, gamma_shape, x_shape
@ -542,3 +546,259 @@ class BatchNormFold2Grad(PrimitiveWithInfer):
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name) validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name)
return gamma_type, gamma_type, gamma_type, gamma_type, gamma_type return gamma_type, gamma_type, gamma_type, gamma_type, gamma_type
class BatchNormFoldD(PrimitiveWithInfer):
"""Performs grad of _BatchNormFold operation."""
@prim_attr_register
def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
"""init _BatchNormFold layer"""
from mindspore.ops._op_impl._custom_op import batchnorm_fold
self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name)
self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
self.data_format = "NCHW"
self.init_prim_io_names(inputs=['x', 'x_sum', 'x_square_sum', 'mean', 'variance'],
outputs=['batch_mean', 'batch_std', 'running_mean', 'running_std',
'mean_updated', 'variance_updated'])
def infer_shape(self, x_shape, x_sum_shape, x_square_sum_shape, mean_shape, variance_shape):
validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, Rel.EQ, self.name)
validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[1], Rel.EQ, self.name)
return x_shape, mean_shape, mean_shape, mean_shape, mean_shape, mean_shape, mean_shape
def infer_dtype(self, x_type, x_sum_type, x_square_sum_type, mean_type, variance_type):
validator.check("input type", x_type, "mean type", mean_type)
validator.check("input type", x_type, "variance type", variance_type)
args = {"x": x_type, "mean": mean_type, "variance": variance_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
return x_type, x_type, x_type, x_type, x_type, x_type, x_type
class BatchNormFoldGradD(PrimitiveWithInfer):
"""Performs grad of _BatchNormFoldGrad operation."""
@prim_attr_register
def __init__(self, epsilon=1e-5, is_training=True, freeze_bn=0):
"""init _BatchNormFoldGrad layer"""
from mindspore.ops._op_impl._custom_op import batchnorm_fold_grad
self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name)
self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
self.init_prim_io_names(inputs=['d_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std'],
outputs=['dx'])
def infer_shape(self, d_batch_mean_shape, d_batch_std_shape, x_shape, batch_mean_shape, batch_std_shape):
validator.check("d_batch_mean shape", d_batch_mean_shape, "d_batch_std shape", d_batch_std_shape)
validator.check("d_batch_mean shape", d_batch_mean_shape, "batch_mean shape", batch_mean_shape)
validator.check("d_batch_mean shape", d_batch_mean_shape, "batch_std shape", batch_std_shape)
validator.check("x_shape shape", d_batch_mean_shape[0], "input channel", x_shape[1])
return x_shape
def infer_dtype(self, d_batch_mean_type, d_batch_std_type, x_type, batch_mean_type, batch_std_type):
validator.check("input type", x_type, "d_batch_mean type", d_batch_mean_type)
validator.check("input type", x_type, "d_batch_std type", d_batch_std_type)
validator.check("input type", x_type, "batch_mean type", batch_mean_type)
validator.check("input type", x_type, "batch_std type", batch_std_type)
args = {"input type": x_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
return x_type
class BNTrainingReduce(PrimitiveWithInfer):
"""
Compute the per-channel sum and square sum of the input over axes [0, 2, 3].
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C, H, W)`.
Outputs:
- **x_sum** (Tensor) - Tensor of shape :math:`(C,)`, the sum of x over axes [0, 2, 3].
- **x_square_sum** (Tensor) - Tensor of shape :math:`(C,)`, the sum of x * x over axes [0, 2, 3].
"""
@prim_attr_register
def __init__(self):
"""init _BNTrainingReduce layer"""
self.init_prim_io_names(inputs=['x'],
outputs=['x_sum', 'x_square_sum'])
def infer_shape(self, x_shape):
return [x_shape[1]], [x_shape[1]]
def infer_dtype(self, x_type):
return x_type, x_type
class BatchNormFold2_D(PrimitiveWithInfer):
"""
Scale the bias with a correction factor toward the long-term statistics
prior to quantization. This ensures that there is no jitter in the quantized bias
due to batch-to-batch variation.
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C)`.
- **beta** (Tensor) - Tensor of shape :math:`(C,)`.
- **gamma** (Tensor) - Tensor of shape :math:`(C,)`.
- **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
- **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`.
- **running_std** (Tensor) - Tensor of shape :math:`(C,)`.
Outputs:
- **y** (Tensor) - Tensor has the same shape as x.
"""
channel_axis = 1
@prim_attr_register
def __init__(self, freeze_bn=0):
"""init conv2d fold layer"""
from mindspore.ops._op_impl._custom_op import batchnorm_fold2
self.init_prim_io_names(inputs=['x', 'beta', 'gamma', 'batch_std', 'batch_mean', 'running_std'],
outputs=['y'])
def infer_shape(self, x_shape, beta_shape, gamma_shape, batch_std_shape, running_std_shape, batch_mean_shape):
validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, Rel.EQ, self.name)
validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, Rel.EQ, self.name)
validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
Rel.EQ, self.name)
return x_shape
def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type):
args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type,
"beta": beta_type, "gamma": gamma_type, "x": x_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
return x_type
class BatchNormFold2GradD(PrimitiveWithInfer):
"""Performs grad of CorrectionAddGrad operation."""
channel_axis = 1
@prim_attr_register
def __init__(self, freeze_bn=False):
"""init MulFold layer"""
from mindspore.ops._op_impl._custom_op import batchnorm_fold2_grad
self.freeze_bn = freeze_bn
self.init_prim_io_names(
inputs=['dout', 'dout_reduce', 'dout_x_reduce', 'gamma', 'batch_std', 'batch_mean', 'running_std'],
outputs=['d_batch_std', 'd_batch_mean', 'd_gamma', 'dx'])
def infer_shape(self, dout_shape, dout_reduce_shape, dout_x_reduce_shape, gamma_shape, batch_std_shape,
batch_mean_shape, running_std_shape):
validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name)
validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel_axis],
Rel.EQ, self.name)
return gamma_shape, gamma_shape, gamma_shape, dout_shape
def infer_dtype(self, dout_type, dout_reduce_type, dout_x_reduce_type, gamma_type, batch_std_type,
batch_mean_type, running_std_type):
validator.check("batch_std type", batch_std_type,
"batch_mean type", batch_mean_type)
validator.check("batch_std type", batch_std_type,
"gamma type", gamma_type)
validator.check("batch_std type", batch_std_type,
"running_std type", running_std_type)
validator.check("batch_std_type", batch_std_type,
"dout type", dout_type)
args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type,
"running_std": running_std_type, "dout": dout_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
return gamma_type, gamma_type, gamma_type, gamma_type
class BatchNormFold2GradReduce(PrimitiveWithInfer):
"""Performs grad of CorrectionAddGrad operation."""
channel_axis = 1
@prim_attr_register
def __init__(self, freeze_bn=False):
"""init MulFold layer"""
from mindspore.ops._op_impl._custom_op import batchnorm_fold2_grad_reduce
self.freeze_bn = freeze_bn
self.init_prim_io_names(inputs=['dout', 'x'],
outputs=['dout_reduce', 'dout_x_reduce'])
def infer_shape(self, dout_shape, x_shape):
validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name)
return (dout_shape[self.channel_axis],), (dout_shape[self.channel_axis],)
def infer_dtype(self, dout_type, x_type):
validator.check("dout type", dout_type, "x type", x_type)
return dout_type, dout_type
class FakeQuantWithMinMaxUpdate(PrimitiveWithInfer):
r"""
Update the min and max values of the input data, which are used to simulate quantization in training time.
Args:
num_bits (int) : Number of bits for quantization. Default: 8.
ema (bool): Whether to update min and max with an EMA algorithm. Default: False.
ema_decay (float) : EMA algorithm decay parameter. Default: 0.999.
quant_delay (int): Quantization delay parameter. The simulated quantization function is not
applied until the training step reaches the delay step. Default: 0.
symmetric (bool): Whether the quantization algorithm is symmetric. Default: False.
narrow_range (bool): Whether the quantization algorithm uses a narrow range. Default: False.
training (bool): Training the network or not. Default: True.
Inputs:
- **x** (Tensor) : float32 Tensor whose min and max ranges are observed.
- **min** (Tensor) : Value of the min range of the input data x.
- **max** (Tensor) : Value of the max range of the input data x.
Outputs:
- **min_up** (Tensor) : Updated value of the min range.
- **max_up** (Tensor) : Updated value of the max range.
Examples:
>>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min_tensor = Tensor(np.array([-6]), mstype.float32)
>>> max_tensor = Tensor(np.array([6]), mstype.float32)
>>> min_up, max_up = P.FakeQuantWithMinMaxUpdate(num_bits=8)(input_tensor, min_tensor, max_tensor)
"""
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, quant_delay=0, symmetric=False, narrow_range=False,
training=True):
"""init FakeQuantWithMinMax OP"""
from mindspore.ops._op_impl._custom_op import correction_mul, correction_mul_grad
from mindspore.ops._op_impl._custom_op import fake_quant_with_min_max, fake_quant_with_min_max_grad
from mindspore.ops._op_impl._custom_op import fake_quant_with_min_max_update
if num_bits not in self.support_quant_bit:
raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.")
if ema and not ema_decay:
raise ValueError(f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name)
self.training = validator.check_value_type('training', training, (bool,), self.name)
self.ema_decay = validator.check_number_range('ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'],
outputs=['min_up', 'max_up'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
validator.check_integer("min rank", len(min_shape), 1, Rel.EQ, self.name)
return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same({"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same({"max": max_type}, valid_types, self.name)
return min_type, max_type
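As a rough reference for the BNTrainingReduce primitive defined above, its two outputs correspond to per-channel sums over the N, H and W axes of an NCHW input. A NumPy sketch, consistent with the docstring (bn_training_reduce_reference is an illustrative name):

import numpy as np

def bn_training_reduce_reference(x):
    """Per-channel sum and square sum over axes [0, 2, 3] of an NCHW tensor."""
    x_sum = np.sum(x, axis=(0, 2, 3))
    x_square_sum = np.sum(x * x, axis=(0, 2, 3))
    return x_sum, x_square_sum

x = np.random.randn(2, 3, 4, 4).astype(np.float32)
x_sum, x_square_sum = bn_training_reduce_reference(x)  # both of shape (3,)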

View File

@ -22,7 +22,7 @@ from mindspore import nn
from mindspore.nn.layer import combined from mindspore.nn.layer import combined
from mindspore.train.quant import quant as qat from mindspore.train.quant import quant as qat
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class LeNet5(nn.Cell): class LeNet5(nn.Cell):
@ -64,7 +64,7 @@ class LeNet5(nn.Cell):
x = self.fc3(x) x = self.fc3(x)
return x return x
"""
def test_qat_lenet(): def test_qat_lenet():
net = LeNet5() net = LeNet5()
net = qat.convert_quant_network( net = qat.convert_quant_network(
@ -92,3 +92,4 @@ def test_qat_mobile_train():
net = nn.WithLossCell(net, loss) net = nn.WithLossCell(net, loss)
net = nn.TrainOneStepCell(net, optimizer) net = nn.TrainOneStepCell(net, optimizer)
net(img, label) net(img, label)
"""