forked from mindspore-Ecosystem/mindspore
split correction_mul op
This commit is contained in:
parent
c749f513ac
commit
c742384a39
|
@ -18,6 +18,7 @@
|
|||
from .. import operations as P
|
||||
from .grad_base import bprop_getters
|
||||
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||
from ... import context
|
||||
|
||||
|
||||
@bprop_getters.register(P.FakeQuantPerLayer)
|
||||
|
@ -64,12 +65,21 @@ def get_bprop_batchnorm_fold(self):
|
|||
@bprop_getters.register(P.CorrectionMul)
|
||||
def get_bprop_correction_mul(self):
|
||||
"""Generate bprop for CorrectionMul for Ascend and GPU"""
|
||||
grad = P.CorrectionMulGrad(self.channel_axis)
|
||||
grad_dx = P.CorrectionMulGrad(self.channel_axis)
|
||||
grad_d_batch_std = P.CorrectionMulGradReduce(self.channel_axis)
|
||||
|
||||
def bprop(x, batch_std, running_std, out, dout):
|
||||
dx, d_batch_std = grad(dout, x, batch_std, running_std)
|
||||
dx, d_batch_std = grad_dx(dout, x, batch_std, running_std)
|
||||
return dx, d_batch_std, zeros_like(running_std)
|
||||
|
||||
def bprop_npu(x, batch_std, running_std, out, dout):
|
||||
dx, mul_dx = grad_dx(dout, x, batch_std, running_std)
|
||||
d_batch_std = grad_d_batch_std(mul_dx)
|
||||
return dx, d_batch_std, zeros_like(running_std)
|
||||
|
||||
if context.get_context('device_target') == "Ascend":
|
||||
return bprop_npu
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ correction_mul_grad_op_info = TBERegOp("CorrectionMulGrad") \
|
|||
.input(2, "batch_std", None, "required", None) \
|
||||
.input(3, "running_std", None, "required", None) \
|
||||
.output(0, "dx", True, "required", "all") \
|
||||
.output(1, "d_batch_std", True, "required", "all") \
|
||||
.output(1, "mul_dx", True, "required", "all") \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
@ -56,21 +56,14 @@ def correction_mul_grad_compute(dout, x, batch_std, running_std, channel, data_f
|
|||
factor = te.lang.cce.vdiv(batch_std, running_std)
|
||||
factor_b = te.lang.cce.broadcast(factor, shape_x)
|
||||
dx = te.lang.cce.vmul(dout, factor_b)
|
||||
mul_data = te.lang.cce.vmul(dout, x)
|
||||
if channel == 0:
|
||||
if data_format == "NCHW":
|
||||
axis = [1, 2, 3]
|
||||
else:
|
||||
axis = [1, 2, 3, 4]
|
||||
else:
|
||||
axis = [2, 3]
|
||||
red_data = te.lang.cce.sum(mul_data, axis, keepdims=True)
|
||||
d_batch_std = te.lang.cce.vdiv(red_data, running_std)
|
||||
return [dx, d_batch_std]
|
||||
mul_dx = te.lang.cce.vmul(dout, x)
|
||||
running_std_b = te.lang.cce.broadcast(running_std, shape_x)
|
||||
mul_dx = te.lang.cce.vdiv(mul_dx, running_std_b)
|
||||
return [dx, mul_dx]
|
||||
|
||||
|
||||
@util.check_input_type(dict, dict, dict, dict, dict, dict, int, str)
|
||||
def correction_mul_grad(dout, x, batch_std, running_std, dx, d_batch_std, channel, kernel_name="correction_mul_grad"):
|
||||
def correction_mul_grad(dout, x, batch_std, running_std, dx, mul_dx, channel, kernel_name="correction_mul_grad"):
|
||||
"""CorrectionMulGrad op"""
|
||||
shape_dout = dout.get("shape")
|
||||
shape_x = dout.get("shape")
|
||||
|
@ -93,7 +86,7 @@ def correction_mul_grad(dout, x, batch_std, running_std, dx, d_batch_std, channe
|
|||
util.compare_tensor_dict_key(dout, x, "shape")
|
||||
util.compare_tensor_dict_key(dx, x, "shape")
|
||||
util.compare_tensor_dict_key(batch_std, running_std, "shape")
|
||||
util.compare_tensor_dict_key(batch_std, d_batch_std, "shape")
|
||||
util.compare_tensor_dict_key(dx, mul_dx, "shape")
|
||||
|
||||
util.check_kernel_name(kernel_name)
|
||||
util.check_shape_rule(shape_x)
|
||||
|
@ -120,7 +113,84 @@ def correction_mul_grad(dout, x, batch_std, running_std, dx, d_batch_std, channe
|
|||
with tvm.target.cce():
|
||||
sch = generic.auto_schedule(res_list)
|
||||
|
||||
tensor_list = [dout_t, x_t, batch_std_t, running_std_t] + list(res_list)
|
||||
tensor_list = [dout_t, x_t, batch_std_t, running_std_t] + res_list
|
||||
config = {"print_ir": False,
|
||||
"name": kernel_name,
|
||||
"tensor_list": tensor_list}
|
||||
|
||||
te.lang.cce.cce_build_code(sch, config)
|
||||
|
||||
|
||||
correction_mul_grad_reduce_op_info = TBERegOp("CorrectionMulGradReduce") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("correction_mul_grad_reduce.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("correction_mul_grad_reduce") \
|
||||
.partial_flag(True) \
|
||||
.op_pattern("formatAgnostic") \
|
||||
.attr("channel_axis", "optional", "int", "all") \
|
||||
.input(0, "dout", None, "required", None) \
|
||||
.output(0, "d_batch_std", True, "required", "all") \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(correction_mul_grad_reduce_op_info)
|
||||
def _correction_mul_grad_reduce_tbe():
|
||||
"""CorrectionMulGradReduce TBE register"""
|
||||
return
|
||||
|
||||
|
||||
@fusion_manager.register("correction_mul_grad_reduce")
|
||||
def correction_mul_grad_reduce_compute(mul_dx, channel, data_format, kernel_name="correction_mul"):
|
||||
"""CorrectionMulGradReduce compute"""
|
||||
if channel == 0:
|
||||
if data_format == "NCHW":
|
||||
axis = [1, 2, 3]
|
||||
else:
|
||||
axis = [1, 2, 3, 4]
|
||||
else:
|
||||
axis = [2, 3]
|
||||
d_batch_std = te.lang.cce.sum(mul_dx, axis, keepdims=True)
|
||||
return d_batch_std
|
||||
|
||||
|
||||
@util.check_input_type(dict, dict, int, str)
|
||||
def correction_mul_grad_reduce(mul_dx, d_batch_std, channel, kernel_name="correction_mul_grad_reduce"):
|
||||
"""CorrectionMulGradReduce op"""
|
||||
shape_dout = mul_dx.get("shape")
|
||||
shape_x = mul_dx.get("shape")
|
||||
|
||||
dtype_dout = mul_dx.get("dtype")
|
||||
|
||||
inp_dtype_dout = dtype_dout.lower()
|
||||
|
||||
util.check_dtype_rule(inp_dtype_dout, ("float16", "float32"))
|
||||
|
||||
util.check_kernel_name(kernel_name)
|
||||
util.check_shape_rule(shape_x)
|
||||
util.check_shape_size(shape_x, SHAPE_SIZE_LIMIT)
|
||||
|
||||
data_format = mul_dx.get("format")
|
||||
ori_format = mul_dx.get("format")
|
||||
if data_format.upper() not in ("NC1HWC0", "NCHW"):
|
||||
raise RuntimeError("Un supported data format {}".format(data_format))
|
||||
if data_format.upper() == "NCHW" and ori_format != "NCHW":
|
||||
raise RuntimeError("data_format(NCHW) must same as ori_format")
|
||||
|
||||
shape_c = [1] * len(shape_x)
|
||||
shape_c[channel] = d_batch_std.get("ori_shape")[0]
|
||||
if data_format == "NC1HWC0" and channel == 1:
|
||||
shape_c = d_batch_std.get("shape")
|
||||
|
||||
dout_t = tvm.placeholder(shape_dout, name="dout", dtype=inp_dtype_dout)
|
||||
res = correction_mul_grad_reduce_compute(dout_t, channel, data_format, kernel_name)
|
||||
|
||||
with tvm.target.cce():
|
||||
sch = generic.auto_schedule(res)
|
||||
|
||||
tensor_list = [dout_t, res]
|
||||
config = {"print_ir": False,
|
||||
"name": kernel_name,
|
||||
"tensor_list": tensor_list}
|
||||
|
|
|
@ -31,10 +31,12 @@ __all__ = ["FakeQuantPerLayer",
|
|||
"BatchNormFoldGrad",
|
||||
"CorrectionMul",
|
||||
"CorrectionMulGrad",
|
||||
"CorrectionMulGradReduce",
|
||||
"BatchNormFold2",
|
||||
"BatchNormFold2Grad",
|
||||
"BatchNormFoldD",
|
||||
"BatchNormFoldGradD",
|
||||
"BNTrainingReduce",
|
||||
"BatchNormFold2_D",
|
||||
"BatchNormFold2GradD",
|
||||
"BatchNormFold2GradReduce",
|
||||
|
@ -332,7 +334,7 @@ class BatchNormFold(PrimitiveWithInfer):
|
|||
Batch normalization folded.
|
||||
|
||||
Args:
|
||||
momentum (float): Momentum value should be [0, 1]. Default: 0.9.
|
||||
momentum (float): Momentum value should be [0, 1]. Default: 0.1.
|
||||
epsilon (float): A small float number to avoid dividing by 0. 1e-5 if dtype in
|
||||
float32 else 1e-3. Default: 1e-5.
|
||||
is_training (bool): In training mode set True, else set False. Default: True.
|
||||
|
@ -364,7 +366,7 @@ class BatchNormFold(PrimitiveWithInfer):
|
|||
channel_axis = 1
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
|
||||
def __init__(self, momentum=0.1, epsilon=1e-5, is_training=True, freeze_bn=0):
|
||||
"""init batch norm fold layer"""
|
||||
self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
|
||||
self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name)
|
||||
|
@ -499,7 +501,7 @@ class CorrectionMulGrad(PrimitiveWithInfer):
|
|||
from mindspore.ops._op_impl._custom_op import correction_mul_grad
|
||||
self.channel_axis = channel_axis
|
||||
self.init_prim_io_names(inputs=['dout', 'x', 'gamma', 'running_std'],
|
||||
outputs=['dx', 'd_gamma'])
|
||||
outputs=['dx', 'mul_dx'])
|
||||
|
||||
def infer_shape(self, dout_shape, x_shape, gamma_shape, running_std_shape):
|
||||
validator.check("dout shape", dout_shape, "x_shape x", x_shape, Rel.EQ, self.name)
|
||||
|
@ -507,12 +509,45 @@ class CorrectionMulGrad(PrimitiveWithInfer):
|
|||
Rel.EQ, self.name)
|
||||
validator.check("running_std_shape[0]", running_std_shape[0],
|
||||
"dout channel size", dout_shape[self.channel_axis], Rel.EQ, self.name)
|
||||
if context.get_context('device_target') == "Ascend":
|
||||
return x_shape, x_shape
|
||||
return x_shape, gamma_shape
|
||||
|
||||
def infer_dtype(self, dout_type, x_type, gamma_type, running_std_type):
|
||||
args = {"dout": dout_type, "x": x_type, "gamma": gamma_type, "running_std": running_std_type}
|
||||
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
|
||||
return x_type, x_type
|
||||
if context.get_context('device_target') == "Ascend":
|
||||
return x_type, x_type
|
||||
return x_type, gamma_type
|
||||
|
||||
|
||||
class CorrectionMulGradReduce(PrimitiveWithInfer):
|
||||
r"""
|
||||
Performs grad reduce of CorrectionMul operation.
|
||||
|
||||
Examples:
|
||||
>>> correction_mul_grad_rd = P.CorrectionMulGradReduce()
|
||||
>>> dout = Tensor(np.array([1.5, -2.2, 0.7, -3, 1.6, 2.8]).reshape(2, 1, 1, 3), mindspore.float32)
|
||||
>>> input_x = Tensor(np.random.randint(0, 256, (2, 1, 1, 3)), mindspore.float32)
|
||||
>>> gamma = Tensor(np.array([0.2, -0.2, 2.5, -1.]).reshape(2, 1, 2), mindspore.float32)
|
||||
>>> running_std = Tensor(np.array([1.2, 0.1, 0.7, 2.3]).reshape(2, 1, 2), mindspore.float32)
|
||||
>>> result = correction_mul_grad_rd(dout, input_x, gamma, running_std)
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, channel_axis=0):
|
||||
"""init correction mul reduce layer"""
|
||||
if context.get_context('device_target') == "Ascend":
|
||||
from mindspore.ops._op_impl._custom_op import correction_mul_grad
|
||||
self.channel_axis = channel_axis
|
||||
self.init_prim_io_names(inputs=['mul_dx'],
|
||||
outputs=['d_gamma'])
|
||||
|
||||
def infer_shape(self, mul_dx_shape):
|
||||
return [mul_dx_shape[self.channel_axis]]
|
||||
|
||||
def infer_dtype(self, mul_dx_type):
|
||||
return mul_dx_type
|
||||
|
||||
|
||||
class BatchNormFold2(PrimitiveWithInfer):
|
||||
|
@ -696,6 +731,32 @@ class BatchNormFoldGradD(PrimitiveWithInfer):
|
|||
return x_type
|
||||
|
||||
|
||||
class BNTrainingReduce(PrimitiveWithInfer):
|
||||
"""
|
||||
reduce sum at axis [0, 2, 3].
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - Tensor of shape :math:`(N, C)`.
|
||||
|
||||
Outputs:
|
||||
- **x_sum** (Tensor) - Tensor has the same shape as x.
|
||||
- **x_square_sum** (Tensor) - Tensor has the same shape as x.
|
||||
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""init _BNTrainingReduce layer"""
|
||||
self.init_prim_io_names(inputs=['x'],
|
||||
outputs=['x_sum', 'x_square_sum'])
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
return [x_shape[1]], [x_shape[1]]
|
||||
|
||||
def infer_dtype(self, x_type):
|
||||
return x_type, x_type
|
||||
|
||||
|
||||
class BatchNormFold2_D(PrimitiveWithInfer):
|
||||
"""
|
||||
Scale the bias with a correction factor to the long term statistics
|
||||
|
|
Loading…
Reference in New Issue