forked from mindspore-Ecosystem/mindspore

reset thor warning

parent 2ec087e5bd
commit cd14e4021c

@@ -14,3 +14,22 @@
 # ============================================================================
 """custom ops"""
+from .batchnorm_fold import _batchnorm_fold_tbe
+from .batchnorm_fold2 import _batchnorm_fold2_tbe
+from .batchnorm_fold2_grad import _batchnorm_fold2_grad_tbe
+from .batchnorm_fold2_grad_reduce import _batchnorm_fold2_grad_reduce_tbe
+from .batchnorm_fold_grad import _batchnorm_fold_grad_tbe
+from .correction_mul import _correction_mul_tbe
+from .correction_mul_grad import _correction_mul_grad_tbe
+from .fake_learned_scale_quant_perlayer import _fake_learned_scale_quant_perlayer_tbe
+from .fake_learned_scale_quant_perlayer_grad import _fake_learned_scale_quant_perlayer_grad_d_tbe
+from .fake_learned_scale_quant_perlayer_grad_reduce import _fake_learned_scale_quant_perlayer_grad_d_reduce_tbe
+from .fake_learned_scale_quant_perchannel import _fake_learned_scale_quant_perchannel_tbe
+from .fake_learned_scale_quant_perchannel_grad import _fake_learned_scale_quant_perchannel_grad_d_tbe
+from .fake_learned_scale_quant_perchannel_grad_reduce import _fake_learned_scale_quant_perchannel_grad_d_reduce_tbe
+from .fake_quant_perchannel import _fake_quant_perchannel_tbe
+from .fake_quant_perchannel_grad import _fake_quant_perchannel_grad_tbe
+from .fake_quant_perlayer import _fake_quant_per_layer_tbe
+from .fake_quant_perlayer_grad import _fake_quant_per_layer_grad_tbe
+from .minmax_update_perchannel import _minmax_update_perchannel_tbe
+from .minmax_update_perlayer import _minmax_update_perlayer_tbe

@@ -1,18 +1,20 @@
-# Copyright 2021 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ===========================================================================
+#!/usr/bin/env python
+# -*- coding:utf-8 -*-
+"""
+copyright 2020 Huawei Technologies Co., Ltd
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
 
 _basic
 """
 from __future__ import absolute_import

@@ -28,11 +28,11 @@ batch_norm_op_info = TBERegOp("BatchNormFoldD") \
     .compute_cost(10) \
     .kernel_name("batchnorm_fold") \
     .partial_flag(True) \
-    .attr("momentum", "optional", "float", "all") \
-    .attr("epsilon", "optional", "float", "all") \
-    .attr("is_training", "optional", "bool", "all") \
-    .attr("freeze_bn", "optional", "int", "all") \
-    .attr("format", "optional", "str", "all") \
+    .attr("momentum", "optional", "float", "all", "0.9") \
+    .attr("epsilon", "optional", "float", "all", "0.00001") \
+    .attr("is_training", "optional", "bool", "all", "true") \
+    .attr("freeze_bn", "optional", "int", "all", "0") \
+    .attr("format", "optional", "str", "all", "NCHW") \
     .input(0, "x", False, "required", "all") \
     .input(1, "x_sum", False, "required", "all") \
     .input(2, "x_square_sum", False, "required", "all") \

@@ -57,43 +57,6 @@ def _batchnorm_fold_tbe():
     return
 
 
-def _batchnorm_fold_compute(x_input, x_sum, x_square_sum, mean, variance, momentum, epsilon):
-    """_batchnorm_fold_compute"""
-    shape_x = te.lang.cce.util.shape_to_list(x_input.shape)
-    num = shape_x[0] * shape_x[2] * shape_x[3]
-    num_rec = 1.0 / num
-
-    # compute the mean of x
-    batch_mean = te.lang.cce.vmuls(x_sum, num_rec)
-
-    # compute the variance of x
-    variance_div = te.lang.cce.vmuls(x_square_sum, num_rec)
-    mean_square = te.lang.cce.vmul(batch_mean, batch_mean)
-    batch_var_biased = te.lang.cce.vsub(variance_div, mean_square)
-    batch_std = te.lang.cce.vsqrt(te.lang.cce.vadds(batch_var_biased, epsilon))
-    if num == 1:
-        batch_var_scaler = 0.0
-    else:
-        batch_var_scaler = float(num) / (num - 1)
-    batch_var_unbiased = te.lang.cce.vmuls(batch_var_biased, batch_var_scaler)
-
-    factor = 1.0 - momentum
-    factor_reverse = momentum
-    mean_mul = te.lang.cce.vmuls(batch_mean, factor)
-    mean_mul_rev = te.lang.cce.vmuls(mean, factor_reverse)
-    mean_updated = te.lang.cce.vadd(mean_mul, mean_mul_rev)
-
-    var_mul = te.lang.cce.vmuls(batch_var_unbiased, factor)
-    var_mul_rev = te.lang.cce.vmuls(variance, factor_reverse)
-    variance_updated = te.lang.cce.vadd(var_mul, var_mul_rev)
-
-    y = te.lang.cce.vadds(x_input, 0.0)
-    running_mean = te.lang.cce.vadds(mean, 0.0)
-    running_std = te.lang.cce.vsqrt(te.lang.cce.vadds(variance, epsilon))
-    res = [y, batch_mean, batch_std, running_mean, running_std, mean_updated, variance_updated]
-    return res
-
-
 @util.check_input_type(dict, dict, dict, dict, dict,
                        dict, dict, dict, dict, dict, dict, dict,
                        float, float, bool, int, str, str)

@@ -145,7 +108,39 @@ def batchnorm_fold(x, x_sum, x_square_sum, mean, variance,
     mean = tvm.placeholder(shape_mean, name="mean", dtype=dtype_mean.lower())
     variance = tvm.placeholder(shape_mean, name="variance", dtype=dtype_variance.lower())
 
-    res = _batchnorm_fold_compute(x_input, x_sum, x_square_sum, mean, variance, momentum, epsilon)
+    shape_x = te.lang.cce.util.shape_to_list(x_input.shape)
+    num = shape_x[0] * shape_x[2] * shape_x[3]
+    num_rec = 1.0 / num
+
+    # compute the mean of x
+    batch_mean = te.lang.cce.vmuls(x_sum, num_rec)
+
+    # compute the variance of x
+    variance_div = te.lang.cce.vmuls(x_square_sum, num_rec)
+    mean_square = te.lang.cce.vmul(batch_mean, batch_mean)
+    batch_var_biased = te.lang.cce.vsub(variance_div, mean_square)
+    batch_std = te.lang.cce.vsqrt(te.lang.cce.vadds(batch_var_biased, epsilon))
+    if num == 1:
+        batch_var_scaler = 0.0
+    else:
+        batch_var_scaler = float(num) / (num - 1)
+    batch_var_unbiased = te.lang.cce.vmuls(batch_var_biased, batch_var_scaler)
+
+    factor = 1.0 - momentum
+    factor_reverse = momentum
+    mean_mul = te.lang.cce.vmuls(batch_mean, factor)
+    mean_mul_rev = te.lang.cce.vmuls(mean, factor_reverse)
+    mean_updated = te.lang.cce.vadd(mean_mul, mean_mul_rev)
+
+    var_mul = te.lang.cce.vmuls(batch_var_unbiased, factor)
+    var_mul_rev = te.lang.cce.vmuls(variance, factor_reverse)
+    variance_updated = te.lang.cce.vadd(var_mul, var_mul_rev)
+
+    y = te.lang.cce.vadds(x_input, 0.0)
+    running_mean = te.lang.cce.vadds(mean, 0.0)
+    running_std = te.lang.cce.vsqrt(te.lang.cce.vadds(variance, epsilon))
+    res = [y, batch_mean, batch_std, running_mean, running_std, mean_updated, variance_updated]
 
     with tvm.target.cce():
         sch = generic.auto_schedule(res)
         config = {"name": kernel_name,

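The inlined block above is plain batch-norm statistics folding expressed with the te.lang.cce vector intrinsics. As a point of reference, here is a minimal NumPy sketch of the same arithmetic (illustrative only, not part of the commit; it assumes NCHW inputs and per-channel x_sum / x_square_sum):

import numpy as np

def batchnorm_fold_reference(x, x_sum, x_square_sum, mean, variance,
                             momentum=0.9, epsilon=1e-5):
    """NumPy sketch of the statistics the inlined batchnorm_fold block computes."""
    n, _, h, w = x.shape                      # NCHW, matching shape_x[0] * shape_x[2] * shape_x[3]
    num = n * h * w
    batch_mean = x_sum / num                  # E[x] per channel
    batch_var_biased = x_square_sum / num - batch_mean ** 2   # E[x^2] - E[x]^2
    batch_std = np.sqrt(batch_var_biased + epsilon)
    scaler = 0.0 if num == 1 else num / (num - 1.0)           # Bessel correction
    batch_var_unbiased = batch_var_biased * scaler
    mean_updated = batch_mean * (1.0 - momentum) + mean * momentum
    variance_updated = batch_var_unbiased * (1.0 - momentum) + variance * momentum
    running_std = np.sqrt(variance + epsilon)
    return x, batch_mean, batch_std, mean, running_std, mean_updated, variance_updated
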
@@ -21,7 +21,6 @@ from te.platform.cce_build import build_config
 from topi import generic
 from topi.cce import util
 from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
-from impl.bn_training_reduce import bn_training_reduce_schedule_nd
 
 SHAPE_SIZE_LIMIT = 2147483648
 

@@ -82,8 +81,9 @@ def batchnorm_fold2_grad_reduce(dout, x, dout_reduce, dout_x_reduce, kernel_name
     util.check_kernel_name(kernel_name)
     util.check_shape_rule(shape)
     util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
+    check_list = ["float16", "float32"]
     inp_dtype = x.get("dtype").lower()
-    if not inp_dtype in ["float16", "float32"]:
+    if not inp_dtype in check_list:
         raise RuntimeError("Dtype of input only support float16, float32")
     dout_t = tvm.placeholder(shape, name="dout", dtype=inp_dtype)
     x_t = tvm.placeholder(shape, name="x", dtype=inp_dtype)

@@ -100,7 +100,7 @@ def batchnorm_fold2_grad_reduce(dout, x, dout_reduce, dout_x_reduce, kernel_name
 
         te.lang.cce.cce_build_code(sch, config)
         return
-
+    from impl.bn_training_reduce import bn_training_reduce_schedule_nd
     sch, tensor_list = bn_training_reduce_schedule_nd(res_list)
     with build_config:
         tvm.build(sch, tensor_list, "cce", name=kernel_name)

@@ -95,12 +95,12 @@ def batchnorm_fold_grad(d_batch_mean, d_batch_std, x, batch_mean, batch_std, dx,
     dtype_mean = d_batch_mean.get("dtype").lower()
     if format_data == "NC1HWC0":
         if len(shape_x) != 5:
-            raise RuntimeError("batchnorm_fold grad only support shape 5D"
+            raise RuntimeError("batchnorm_fold only support shape 5D"
                                "when input format is NC1HWC0")
         shape_mean = (1, shape_x[1], 1, 1, shape_x[4])
     elif format_data == "NCHW":
         if len(shape_x) < 2 or len(shape_x) > 4:
-            raise RuntimeError("batchnorm_fold grad only support shape 2D to 4D")
+            raise RuntimeError("batchnorm_fold only support shape 2D to 4D")
         if shape_x[1] != shape_mean[0]:
             raise RuntimeError("data_format is NCHW, shape_bias must"
                                "be equal to the second axis of shape_x")

@@ -66,8 +66,9 @@ def correction_mul(x, batch_std, running_std, y, channel, kernel_name="correctio
     util.check_kernel_name(kernel_name)
     util.check_shape_rule(shape)
     util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
+    check_list = ["float16", "float32"]
     inp_dtype = x.get("dtype").lower()
-    if not inp_dtype in ["float16", "float32"]:
+    if not inp_dtype in check_list:
         raise RuntimeError("Dtype of input only support float16, float32")
 
     x_t = tvm.placeholder(shape, name="x", dtype=inp_dtype)

@@ -86,9 +86,11 @@ def fake_quant_perchannel_compute(x, min_val, max_val, y, quant_min, quant_max,
     return res
 
 
-def fake_quant_perchannel_param(x, min_val, max_val, channel_axis,
-                                kernel_name="fake_quant_perchannel"):
-    """Get and check fake_quant_perchannel parameters"""
+@util.check_input_type(dict, dict, dict, dict, bool, bool, int, int, str)
+def fake_quant_perchannel(x, min_val, max_val, y,
+                          symmetric, narrow_range, num_bits, channel_axis,
+                          kernel_name="fake_quant_perchannel"):
+    """FakeQuantPerChannel"""
     x_shape = x.get("shape")
     x_shape_ = x.get("ori_shape")
     x_format = x.get("format")

@@ -118,25 +120,15 @@ def fake_quant_perchannel_param(x, min_val, max_val, channel_axis,
     util.check_dtype_rule(min_dtype, check_list)
     util.check_dtype_rule(max_dtype, check_list)
 
-    shape_c = [1] * len(x_shape)
-    shape_c[channel_axis_] = min_val.get("ori_shape")[0]
-    if x_format == "NC1HWC0" and channel_axis_ == 1:
-        shape_c = min_val.get("shape")
-    return x_shape, shape_c, x_dtype
-
-
-@util.check_input_type(dict, dict, dict, dict, bool, bool, int, int, str)
-def fake_quant_perchannel(x, min_val, max_val, y,
-                          symmetric, narrow_range, num_bits, channel_axis,
-                          kernel_name="fake_quant_perchannel"):
-    """FakeQuantPerChannel"""
     quant_min = 0
     quant_max = 2 ** num_bits - 1
     if narrow_range:
         quant_min = quant_min + 1
 
-    x_shape, shape_c, x_dtype = fake_quant_perchannel_param(x, min_val, max_val,
-                                                            channel_axis, kernel_name)
+    shape_c = [1] * len(x_shape)
+    shape_c[channel_axis_] = min_val.get("ori_shape")[0]
+    if x_format == "NC1HWC0" and channel_axis_ == 1:
+        shape_c = min_val.get("shape")
     input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
     min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
     max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)

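The change above only inlines the parameter handling: it derives the integer range (quant_min, quant_max) and the broadcast shape shape_c for the per-channel min/max tensors, while the quantization itself lives in fake_quant_perchannel_compute, which this diff does not touch. For orientation, a rough NumPy sketch of a standard per-channel fake-quant forward pass under those parameters (an assumed formulation, not the kernel's exact compute):

import numpy as np

def fake_quant_perchannel_reference(x, min_val, max_val, num_bits=8,
                                    narrow_range=False, channel_axis=0):
    """Illustrative per-channel fake quantization (assumed standard formulation)."""
    quant_min = 1 if narrow_range else 0      # same range logic as the kernel wrapper
    quant_max = 2 ** num_bits - 1
    shape_c = [1] * x.ndim                    # broadcast shape, analogous to shape_c above
    shape_c[channel_axis] = -1
    min_val = min_val.reshape(shape_c)
    max_val = max_val.reshape(shape_c)
    scale = (max_val - min_val) / (quant_max - quant_min)   # assumes max_val > min_val
    zero_point = np.round(quant_min - min_val / scale)
    quantized = np.clip(np.round(x / scale + zero_point), quant_min, quant_max)
    return (quantized - zero_point) * scale   # dequantized ("fake" quantized) output
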
@@ -110,9 +110,11 @@ def fake_quant_perchannel_grad_compute(dout, x, min_val, max_val, quant_min, qua
     return res
 
 
-def fake_quant_perchannel_grad_param(x, min_val, max_val, channel_axis,
-                                     kernel_name="fake_quant_perchannel_grad"):
-    """Get and check FakeQuantPerChannelGrad parameters"""
+@util.check_input_type(dict, dict, dict, dict, dict, bool, bool, int, int, str)
+def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
+                               symmetric, narrow_range, num_bits, channel_axis,
+                               kernel_name="fake_quant_perchannel_grad"):
+    """FakeQuantPerChannelGrad"""
     x_shape = x.get("shape")
     x_shape_ = x.get("ori_shape")
     x_format = x.get("format")

@@ -142,18 +144,6 @@ def fake_quant_perchannel_grad_param(x, min_val, max_val, channel_axis,
     util.check_dtype_rule(min_dtype, check_list)
     util.check_dtype_rule(max_dtype, check_list)
 
-    shape_c = [1] * len(x_shape)
-    shape_c[channel_axis_] = min_val.get("ori_shape")[0]
-    if x_format == "NC1HWC0" and channel_axis_ == 1:
-        shape_c = min_val.get("shape")
-    return x_shape, shape_c, x_dtype
-
-
-@util.check_input_type(dict, dict, dict, dict, dict, bool, bool, int, int, str)
-def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
-                               symmetric, narrow_range, num_bits, channel_axis,
-                               kernel_name="fake_quant_perchannel_grad"):
-    """FakeQuantPerChannelGrad"""
     if symmetric:
         quant_min = 0 - 2 ** (num_bits - 1)
         quant_max = 2 ** (num_bits - 1) - 1

@@ -163,8 +153,10 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
     if narrow_range:
         quant_min = quant_min + 1
 
-    x_shape, shape_c, x_dtype = fake_quant_perchannel_grad_param(x, min_val, max_val,
-                                                                 channel_axis, kernel_name)
+    shape_c = [1] * len(x_shape)
+    shape_c[channel_axis_] = min_val.get("ori_shape")[0]
+    if x_format == "NC1HWC0" and channel_axis_ == 1:
+        shape_c = min_val.get("shape")
     dout_data = tvm.placeholder(x_shape, name="dout", dtype=x_dtype)
     input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
     min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)

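FakeQuantPerChannelGrad is conventionally a straight-through estimator: the incoming gradient passes through wherever x lies inside the per-channel [min, max] range and is zeroed elsewhere. A hedged NumPy sketch of that behaviour (the actual TBE compute is outside this diff):

import numpy as np

def fake_quant_perchannel_grad_reference(dout, x, min_val, max_val, channel_axis=0):
    """Illustrative straight-through gradient for per-channel fake quantization."""
    shape_c = [1] * x.ndim
    shape_c[channel_axis] = -1
    min_val = min_val.reshape(shape_c)
    max_val = max_val.reshape(shape_c)
    in_range = np.logical_and(x >= min_val, x <= max_val)   # pass-through mask
    return dout * in_range.astype(dout.dtype)
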
@@ -82,6 +82,7 @@ def minmax_update_perchannel(x, min_val, max_val, min_up, max_up,
                              kernel_name="minmax_update_perchannel"):
     """MinMaxUpdatePerChannel op"""
     x_shape = x.get("ori_shape")
+    x_format = x.get("format")
     x_dtype = x.get("dtype")
     min_shape = min_val.get("ori_shape")
     min_dtype = min_val.get("dtype")

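MinMaxUpdatePerChannel tracks the observed per-channel minimum and maximum of x that later feed the fake-quant kernels; this hunk only adds the x_format lookup. A rough NumPy sketch of the usual update, assuming a plain EMA-style rule (the ema and ema_decay names are illustrative, not taken from this diff):

import numpy as np

def minmax_update_perchannel_reference(x, min_val, max_val, channel_axis=1,
                                       ema=False, ema_decay=0.999):
    """Illustrative per-channel min/max tracking for quantization-aware training."""
    # Reduce over every axis except the channel axis.
    reduce_axes = tuple(i for i in range(x.ndim) if i != channel_axis)
    batch_min = x.min(axis=reduce_axes)
    batch_max = x.max(axis=reduce_axes)
    if ema:
        min_up = min_val * ema_decay + batch_min * (1.0 - ema_decay)
        max_up = max_val * ema_decay + batch_max * (1.0 - ema_decay)
    else:
        min_up = batch_min
        max_up = batch_max
    # Keep the range valid for quantization (min <= 0 <= max is common practice).
    return np.minimum(min_up, 0.0), np.maximum(max_up, 0.0)
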