!1791 fix compile bugs in quant.py and correction_mul_grad

Merge pull request !1791 from wandongdong/master
mindspore-ci-bot 2020-06-02 16:21:22 +08:00 committed by Gitee
commit 9c305813b1
4 changed files with 14 additions and 16 deletions

View File

@@ -21,7 +21,7 @@
#include "fake_quant_impl.cuh"
__global__ void FakeQuantize(const float *input, float *output, const int size, const float *nudge_min,
-const float *nudge_max, const float *scale, bool symmetric) {
+const float *nudge_max, const float *scale) {
float input_x = 0.f;
int nudge_input = 0;
@@ -35,7 +35,7 @@ __global__ void FakeQuantize(const float *input, float *output, const int size,
input_x = nudge_max[0];
}
// clamp shift
-nudge_input = floor((input_x - nudge_min[0]) / scale[0] + 0.5f);
+nudge_input = round((input_x - nudge_min[0]) / scale[0]);
// quantize
output[i] = nudge_input * scale[0] + nudge_min[0];
@@ -99,8 +99,7 @@ __global__ void UpdateInputMinMax(float *input_min, float *input_max, const floa
void CalFakeQuantize(const float *input, float *output, const int size, const float *nudge_min, const float *nudge_max,
const float *scale, bool symmetric, cudaStream_t cuda_stream) {
-FakeQuantize<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(input, output, size, nudge_min, nudge_max, scale,
-symmetric);
+FakeQuantize<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(input, output, size, nudge_min, nudge_max, scale);
return;
}
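For reference, a minimal NumPy sketch of the fake-quantize step this kernel now performs (layer-wise case with scalar nudge_min/nudge_max/scale assumed; the helper name is illustrative). Away from exact .5 ties, round((x - nudge_min) / scale) picks the same integer step as the previous floor((x - nudge_min) / scale + 0.5f), so the change mainly removes the manual offset (tie-breaking at exact .5 can still differ between the two formulations):

```python
import numpy as np

def fake_quantize_ref(x, nudge_min, nudge_max, scale):
    """Sketch of the per-element fake-quantize step (layer-wise quantization)."""
    clamped = np.clip(x, nudge_min, nudge_max)      # clamp into the nudged range
    step = np.round((clamped - nudge_min) / scale)  # quantize to an integer step
    return step * scale + nudge_min                 # dequantize back to float
```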

View File

@@ -22,7 +22,7 @@ from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor
from mindspore._checkparam import check_int_positive, check_bool, twice
-from mindspore._checkparam import Validator as validator
+from mindspore._checkparam import Validator as validator, Rel
from mindspore.nn.cell import Cell
from mindspore.nn.layer.activation import get_activation
import mindspore.context as context
@@ -207,7 +207,7 @@ class FakeQuantWithMinMaxD(Cell):
class FakeQuantWithMinMax(Cell):
r"""
-Aware Quantization training op. This OP provide Fake quantization observer function on data with min and max.
+Aware Quantization op. This OP provide Fake quantization observer function on data with min and max.
Args:
min_init (int, list): The dimension of channel or 1(layer). Default: -6.
@@ -243,8 +243,7 @@
out_channels=1,
quant_delay=0,
symmetric=False,
-narrow_range=False,
-training=True):
+narrow_range=False):
"""init FakeQuantWithMinMax layer"""
super(FakeQuantWithMinMax, self).__init__()
@@ -258,7 +257,6 @@
self.quant_delay = quant_delay
self.symmetric = symmetric
self.narrow_range = narrow_range
-self.training = training
if per_channel:
min_array = np.array([self.min_init for i in range(0, self.out_channels)]).astype(np.float32)
@@ -422,11 +420,13 @@ class Conv2dBatchNormQuant(Cell):
self.per_channel = per_channel
self.symmetric = symmetric
self.narrow_range = narrow_range
+self.channel_axis = int(group > 1)
+self.is_gpu = context.get_context('device_target') == "GPU"
# initialize convolution op and Parameter
if context.get_context('device_target') == "Ascend" and group > 1:
-validator.check_integer('group', group, 'in_channels', in_channels, 'Conv2dBatchNormQuant')
-validator.check_integer('group', group, 'in_channels', out_channels, 'Conv2dBatchNormQuant')
+validator.check_integer('group', group, in_channels, Rel.EQ, 'Conv2dBatchNormQuant')
+validator.check_integer('group', group, out_channels, Rel.EQ, 'Conv2dBatchNormQuant')
self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
kernel_size=self.kernel_size,
pad_mode=pad_mode,
@@ -472,7 +472,7 @@ class Conv2dBatchNormQuant(Cell):
symmetric=symmetric,
narrow_range=narrow_range)
self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn)
-self.correct_mul = P.CorrectionMul()
+self.correct_mul = P.CorrectionMul(self.channel_axis)
if context.get_context('device_target') == "Ascend":
self.batchnorm_fold2_train = P.BatchNormFold2_D(freeze_bn=freeze_bn)
self.batchnorm_fold2_infer = P.BatchNormFold2_D(freeze_bn=0)
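For context on the channel_axis argument now passed to P.CorrectionMul: int(group > 1) selects weight axis 1 for the depthwise (group > 1) layout and axis 0 otherwise. A rough NumPy sketch of the scaling the operator is assumed to apply (helper name and broadcasting details are illustrative, not the operator's actual implementation):

```python
import numpy as np

def correction_mul_ref(weight, batch_std, running_std, channel_axis):
    """Illustrative sketch: scale each channel of the folded weight by
    batch_std / running_std, broadcast along channel_axis."""
    shape = [1] * weight.ndim
    shape[channel_axis] = -1  # axis 0 for normal conv weights,
                              # axis 1 for the depthwise layout picked by int(group > 1)
    return weight * (batch_std / running_std).reshape(shape)
```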

View File

@@ -93,8 +93,8 @@ def correction_mul_grad(dout, x, batch_std, running_std, dx, d_batch_std, channe
util.check_dtype_rule(inp_dtype_dout, ("float16", "float32"))
util.check_dtype_rule(inp_dtype_x, ("float16", "float32"))
-util.check_dtype_rule(inp_dtype_batch_std, ("float32",))
-util.check_dtype_rule(inp_dtype_running_std, ("float32",))
+util.check_dtype_rule(inp_dtype_batch_std, ("float16", "float32"))
+util.check_dtype_rule(inp_dtype_running_std, ("float16", "float32"))
util.compare_tensor_dict_key(dout, x, "dtype")
util.compare_tensor_dict_key(dout, x, "shape")
util.compare_tensor_dict_key(dx, x, "shape")
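The checks above are relaxed so batch_std and running_std may also be float16. For orientation, the gradient this kernel is expected to compute, assuming a forward pass of y = x * (batch_std / running_std) broadcast along channel_axis, can be sketched in NumPy as follows (names and axis handling are illustrative only):

```python
import numpy as np

def correction_mul_grad_ref(dout, x, batch_std, running_std, channel_axis=0):
    """Illustrative gradient sketch for y = x * (batch_std / running_std)."""
    shape = [1] * x.ndim
    shape[channel_axis] = -1
    dx = dout * (batch_std / running_std).reshape(shape)           # dy/dx
    reduce_axes = tuple(a for a in range(x.ndim) if a != channel_axis)
    d_batch_std = np.sum(dout * x / running_std.reshape(shape),    # dy/d(batch_std)
                         axis=reduce_axes)
    return dx, d_batch_std
```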

View File

@@ -80,8 +80,7 @@ def fake_quant_with_min_max_vars_ema_compute(x, min_val, max_val, y, quant_min,
# FakeQuant
input_x = te.lang.cce.vmin(nudge_max, te.lang.cce.vmax(nudge_min, x))
-nudge_input = te.lang.cce.floor(te.lang.cce.vadds(te.lang.cce.vdiv(te.lang.cce.vsub(input_x, nudge_min), scale),
-0.5))
+nudge_input = te.lang.cce.round(te.lang.cce.vdiv(te.lang.cce.vsub(input_x, nudge_min), scale))
res = te.lang.cce.vadd(te.lang.cce.vmul(nudge_input, scale), nudge_min)
return res
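This mirrors the CUDA change in the first file: floor(v + 0.5) is replaced with a direct round. A quick NumPy check (illustrative only; tie-breaking at exact .5 can still differ between floor(v + 0.5), np.round, and the hardware round intrinsics) shows the two agree away from ties:

```python
import numpy as np

v = np.array([-2.3, -0.7, 0.0, 0.4, 1.6, 3.2])
# Both formulations select the same integer step for values away from exact .5 ties.
assert np.array_equal(np.floor(v + 0.5), np.round(v))
```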