reset thor warning

cmy_melody 2021-06-08 10:46:42 +08:00
parent 2ec087e5bd
commit cd14e4021c
9 changed files with 99 additions and 97 deletions

View File

@@ -14,3 +14,22 @@
# ============================================================================
"""custom ops"""
from .batchnorm_fold import _batchnorm_fold_tbe
from .batchnorm_fold2 import _batchnorm_fold2_tbe
from .batchnorm_fold2_grad import _batchnorm_fold2_grad_tbe
from .batchnorm_fold2_grad_reduce import _batchnorm_fold2_grad_reduce_tbe
from .batchnorm_fold_grad import _batchnorm_fold_grad_tbe
from .correction_mul import _correction_mul_tbe
from .correction_mul_grad import _correction_mul_grad_tbe
from .fake_learned_scale_quant_perlayer import _fake_learned_scale_quant_perlayer_tbe
from .fake_learned_scale_quant_perlayer_grad import _fake_learned_scale_quant_perlayer_grad_d_tbe
from .fake_learned_scale_quant_perlayer_grad_reduce import _fake_learned_scale_quant_perlayer_grad_d_reduce_tbe
from .fake_learned_scale_quant_perchannel import _fake_learned_scale_quant_perchannel_tbe
from .fake_learned_scale_quant_perchannel_grad import _fake_learned_scale_quant_perchannel_grad_d_tbe
from .fake_learned_scale_quant_perchannel_grad_reduce import _fake_learned_scale_quant_perchannel_grad_d_reduce_tbe
from .fake_quant_perchannel import _fake_quant_perchannel_tbe
from .fake_quant_perchannel_grad import _fake_quant_perchannel_grad_tbe
from .fake_quant_perlayer import _fake_quant_per_layer_tbe
from .fake_quant_perlayer_grad import _fake_quant_per_layer_grad_tbe
from .minmax_update_perchannel import _minmax_update_perchannel_tbe
from .minmax_update_perlayer import _minmax_update_perlayer_tbe

View File

@@ -1,18 +1,20 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
copyright 2020 Huawei Technologies Co., Ltd
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License == distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
_basic
"""
from __future__ import absolute_import

View File

@@ -28,11 +28,11 @@ batch_norm_op_info = TBERegOp("BatchNormFoldD") \
    .compute_cost(10) \
    .kernel_name("batchnorm_fold") \
    .partial_flag(True) \
    .attr("momentum", "optional", "float", "all") \
    .attr("epsilon", "optional", "float", "all") \
    .attr("is_training", "optional", "bool", "all") \
    .attr("freeze_bn", "optional", "int", "all") \
    .attr("format", "optional", "str", "all") \
    .attr("momentum", "optional", "float", "all", "0.9") \
    .attr("epsilon", "optional", "float", "all", "0.00001") \
    .attr("is_training", "optional", "bool", "all", "true") \
    .attr("freeze_bn", "optional", "int", "all", "0") \
    .attr("format", "optional", "str", "all", "NCHW") \
    .input(0, "x", False, "required", "all") \
    .input(1, "x_sum", False, "required", "all") \
    .input(2, "x_square_sum", False, "required", "all") \
@@ -57,43 +57,6 @@ def _batchnorm_fold_tbe():
    return
def _batchnorm_fold_compute(x_input, x_sum, x_square_sum, mean, variance, momentum, epsilon):
    """_batchnorm_fold_compute"""
    shape_x = te.lang.cce.util.shape_to_list(x_input.shape)
    num = shape_x[0] * shape_x[2] * shape_x[3]
    num_rec = 1.0 / num
    # compute the mean of x
    batch_mean = te.lang.cce.vmuls(x_sum, num_rec)
    # compute the variance of x
    variance_div = te.lang.cce.vmuls(x_square_sum, num_rec)
    mean_square = te.lang.cce.vmul(batch_mean, batch_mean)
    batch_var_biased = te.lang.cce.vsub(variance_div, mean_square)
    batch_std = te.lang.cce.vsqrt(te.lang.cce.vadds(batch_var_biased, epsilon))
    if num == 1:
        batch_var_scaler = 0.0
    else:
        batch_var_scaler = float(num) / (num - 1)
    batch_var_unbiased = te.lang.cce.vmuls(batch_var_biased, batch_var_scaler)
    factor = 1.0 - momentum
    factor_reverse = momentum
    mean_mul = te.lang.cce.vmuls(batch_mean, factor)
    mean_mul_rev = te.lang.cce.vmuls(mean, factor_reverse)
    mean_updated = te.lang.cce.vadd(mean_mul, mean_mul_rev)
    var_mul = te.lang.cce.vmuls(batch_var_unbiased, factor)
    var_mul_rev = te.lang.cce.vmuls(variance, factor_reverse)
    variance_updated = te.lang.cce.vadd(var_mul, var_mul_rev)
    y = te.lang.cce.vadds(x_input, 0.0)
    running_mean = te.lang.cce.vadds(mean, 0.0)
    running_std = te.lang.cce.vsqrt(te.lang.cce.vadds(variance, epsilon))
    res = [y, batch_mean, batch_std, running_mean, running_std, mean_updated, variance_updated]
    return res
@util.check_input_type(dict, dict, dict, dict, dict,
                       dict, dict, dict, dict, dict, dict, dict,
                       float, float, bool, int, str, str)
@@ -145,7 +108,39 @@ def batchnorm_fold(x, x_sum, x_square_sum, mean, variance,
    mean = tvm.placeholder(shape_mean, name="mean", dtype=dtype_mean.lower())
    variance = tvm.placeholder(shape_mean, name="variance", dtype=dtype_variance.lower())
    res = _batchnorm_fold_compute(x_input, x_sum, x_square_sum, mean, variance, momentum, epsilon)
    shape_x = te.lang.cce.util.shape_to_list(x_input.shape)
    num = shape_x[0] * shape_x[2] * shape_x[3]
    num_rec = 1.0 / num
    # compute the mean of x
    batch_mean = te.lang.cce.vmuls(x_sum, num_rec)
    # compute the variance of x
    variance_div = te.lang.cce.vmuls(x_square_sum, num_rec)
    mean_square = te.lang.cce.vmul(batch_mean, batch_mean)
    batch_var_biased = te.lang.cce.vsub(variance_div, mean_square)
    batch_std = te.lang.cce.vsqrt(te.lang.cce.vadds(batch_var_biased, epsilon))
    if num == 1:
        batch_var_scaler = 0.0
    else:
        batch_var_scaler = float(num) / (num - 1)
    batch_var_unbiased = te.lang.cce.vmuls(batch_var_biased, batch_var_scaler)
    factor = 1.0 - momentum
    factor_reverse = momentum
    mean_mul = te.lang.cce.vmuls(batch_mean, factor)
    mean_mul_rev = te.lang.cce.vmuls(mean, factor_reverse)
    mean_updated = te.lang.cce.vadd(mean_mul, mean_mul_rev)
    var_mul = te.lang.cce.vmuls(batch_var_unbiased, factor)
    var_mul_rev = te.lang.cce.vmuls(variance, factor_reverse)
    variance_updated = te.lang.cce.vadd(var_mul, var_mul_rev)
    y = te.lang.cce.vadds(x_input, 0.0)
    running_mean = te.lang.cce.vadds(mean, 0.0)
    running_std = te.lang.cce.vsqrt(te.lang.cce.vadds(variance, epsilon))
    res = [y, batch_mean, batch_std, running_mean, running_std, mean_updated, variance_updated]
    with tvm.target.cce():
        sch = generic.auto_schedule(res)
    config = {"name": kernel_name,

View File

@@ -21,7 +21,6 @@ from te.platform.cce_build import build_config
from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
from impl.bn_training_reduce import bn_training_reduce_schedule_nd
SHAPE_SIZE_LIMIT = 2147483648
@@ -82,8 +81,9 @@ def batchnorm_fold2_grad_reduce(dout, x, dout_reduce, dout_x_reduce, kernel_name
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape)
    util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
    check_list = ["float16", "float32"]
    inp_dtype = x.get("dtype").lower()
    if not inp_dtype in ["float16", "float32"]:
    if not inp_dtype in check_list:
        raise RuntimeError("Dtype of input only support float16, float32")
    dout_t = tvm.placeholder(shape, name="dout", dtype=inp_dtype)
    x_t = tvm.placeholder(shape, name="x", dtype=inp_dtype)
@@ -100,7 +100,7 @@ def batchnorm_fold2_grad_reduce(dout, x, dout_reduce, dout_x_reduce, kernel_name
        te.lang.cce.cce_build_code(sch, config)
        return
    from impl.bn_training_reduce import bn_training_reduce_schedule_nd
    sch, tensor_list = bn_training_reduce_schedule_nd(res_list)
    with build_config:
        tvm.build(sch, tensor_list, "cce", name=kernel_name)

View File

@@ -95,12 +95,12 @@ def batchnorm_fold_grad(d_batch_mean, d_batch_std, x, batch_mean, batch_std, dx,
    dtype_mean = d_batch_mean.get("dtype").lower()
    if format_data == "NC1HWC0":
        if len(shape_x) != 5:
            raise RuntimeError("batchnorm_fold grad only support shape 5D"
            raise RuntimeError("batchnorm_fold only support shape 5D"
                               "when input format is NC1HWC0")
        shape_mean = (1, shape_x[1], 1, 1, shape_x[4])
    elif format_data == "NCHW":
        if len(shape_x) < 2 or len(shape_x) > 4:
            raise RuntimeError("batchnorm_fold grad only support shape 2D to 4D")
            raise RuntimeError("batchnorm_fold only support shape 2D to 4D")
        if shape_x[1] != shape_mean[0]:
            raise RuntimeError("data_format is NCHW, shape_bias must"
                               "be equal to the second axis of shape_x")

View File

@@ -66,8 +66,9 @@ def correction_mul(x, batch_std, running_std, y, channel, kernel_name="correctio
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape)
    util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
    check_list = ["float16", "float32"]
    inp_dtype = x.get("dtype").lower()
    if not inp_dtype in ["float16", "float32"]:
    if not inp_dtype in check_list:
        raise RuntimeError("Dtype of input only support float16, float32")
    x_t = tvm.placeholder(shape, name="x", dtype=inp_dtype)

View File

@@ -86,9 +86,11 @@ def fake_quant_perchannel_compute(x, min_val, max_val, y, quant_min, quant_max,
    return res
def fake_quant_perchannel_param(x, min_val, max_val, channel_axis,
                                kernel_name="fake_quant_perchannel"):
    """Get and check fake_quant_perchannel parameters"""
@util.check_input_type(dict, dict, dict, dict, bool, bool, int, int, str)
def fake_quant_perchannel(x, min_val, max_val, y,
                          symmetric, narrow_range, num_bits, channel_axis,
                          kernel_name="fake_quant_perchannel"):
    """FakeQuantPerChannel"""
    x_shape = x.get("shape")
    x_shape_ = x.get("ori_shape")
    x_format = x.get("format")
@@ -118,25 +120,15 @@ def fake_quant_perchannel_param(x, min_val, max_val, channel_axis,
    util.check_dtype_rule(min_dtype, check_list)
    util.check_dtype_rule(max_dtype, check_list)
    shape_c = [1] * len(x_shape)
    shape_c[channel_axis_] = min_val.get("ori_shape")[0]
    if x_format == "NC1HWC0" and channel_axis_ == 1:
        shape_c = min_val.get("shape")
    return x_shape, shape_c, x_dtype
@util.check_input_type(dict, dict, dict, dict, bool, bool, int, int, str)
def fake_quant_perchannel(x, min_val, max_val, y,
                          symmetric, narrow_range, num_bits, channel_axis,
                          kernel_name="fake_quant_perchannel"):
    """FakeQuantPerChannel"""
    quant_min = 0
    quant_max = 2 ** num_bits - 1
    if narrow_range:
        quant_min = quant_min + 1
    x_shape, shape_c, x_dtype = fake_quant_perchannel_param(x, min_val, max_val,
                                                            channel_axis, kernel_name)
    shape_c = [1] * len(x_shape)
    shape_c[channel_axis_] = min_val.get("ori_shape")[0]
    if x_format == "NC1HWC0" and channel_axis_ == 1:
        shape_c = min_val.get("shape")
    input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
    min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
    max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)

View File

@@ -110,9 +110,11 @@ def fake_quant_perchannel_grad_compute(dout, x, min_val, max_val, quant_min, qua
    return res
def fake_quant_perchannel_grad_param(x, min_val, max_val, channel_axis,
                                     kernel_name="fake_quant_perchannel_grad"):
    """Get and check FakeQuantPerChannelGrad parameters"""
@util.check_input_type(dict, dict, dict, dict, dict, bool, bool, int, int, str)
def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
                               symmetric, narrow_range, num_bits, channel_axis,
                               kernel_name="fake_quant_perchannel_grad"):
    """FakeQuantPerChannelGrad"""
    x_shape = x.get("shape")
    x_shape_ = x.get("ori_shape")
    x_format = x.get("format")
@@ -142,18 +144,6 @@ def fake_quant_perchannel_grad_param(x, min_val, max_val, channel_axis,
    util.check_dtype_rule(min_dtype, check_list)
    util.check_dtype_rule(max_dtype, check_list)
    shape_c = [1] * len(x_shape)
    shape_c[channel_axis_] = min_val.get("ori_shape")[0]
    if x_format == "NC1HWC0" and channel_axis_ == 1:
        shape_c = min_val.get("shape")
    return x_shape, shape_c, x_dtype
@util.check_input_type(dict, dict, dict, dict, dict, bool, bool, int, int, str)
def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
                               symmetric, narrow_range, num_bits, channel_axis,
                               kernel_name="fake_quant_perchannel_grad"):
    """FakeQuantPerChannelGrad"""
    if symmetric:
        quant_min = 0 - 2 ** (num_bits - 1)
        quant_max = 2 ** (num_bits - 1) - 1
@@ -163,8 +153,10 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
    if narrow_range:
        quant_min = quant_min + 1
    x_shape, shape_c, x_dtype = fake_quant_perchannel_grad_param(x, min_val, max_val,
                                                                 channel_axis, kernel_name)
    shape_c = [1] * len(x_shape)
    shape_c[channel_axis_] = min_val.get("ori_shape")[0]
    if x_format == "NC1HWC0" and channel_axis_ == 1:
        shape_c = min_val.get("shape")
    dout_data = tvm.placeholder(x_shape, name="dout", dtype=x_dtype)
    input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
    min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)

View File

@@ -82,6 +82,7 @@ def minmax_update_perchannel(x, min_val, max_val, min_up, max_up,
                             kernel_name="minmax_update_perchannel"):
    """MinMaxUpdatePerChannel op"""
    x_shape = x.get("ori_shape")
    x_format = x.get("format")
    x_dtype = x.get("dtype")
    min_shape = min_val.get("ori_shape")
    min_dtype = min_val.get("dtype")