forked from mindspore-Ecosystem/mindspore
!17785 clean code for thor
Merge pull request !17785 from melody/master
commit 346a35ba7d
@@ -14,22 +14,3 @@
# ============================================================================

"""custom ops"""
from .batchnorm_fold import _batchnorm_fold_tbe
from .batchnorm_fold2 import _batchnorm_fold2_tbe
from .batchnorm_fold2_grad import _batchnorm_fold2_grad_tbe
from .batchnorm_fold2_grad_reduce import _batchnorm_fold2_grad_reduce_tbe
from .batchnorm_fold_grad import _batchnorm_fold_grad_tbe
from .correction_mul import _correction_mul_tbe
from .correction_mul_grad import _correction_mul_grad_tbe
from .fake_learned_scale_quant_perlayer import _fake_learned_scale_quant_perlayer_tbe
from .fake_learned_scale_quant_perlayer_grad import _fake_learned_scale_quant_perlayer_grad_d_tbe
from .fake_learned_scale_quant_perlayer_grad_reduce import _fake_learned_scale_quant_perlayer_grad_d_reduce_tbe
from .fake_learned_scale_quant_perchannel import _fake_learned_scale_quant_perchannel_tbe
from .fake_learned_scale_quant_perchannel_grad import _fake_learned_scale_quant_perchannel_grad_d_tbe
from .fake_learned_scale_quant_perchannel_grad_reduce import _fake_learned_scale_quant_perchannel_grad_d_reduce_tbe
from .fake_quant_perchannel import _fake_quant_perchannel_tbe
from .fake_quant_perchannel_grad import _fake_quant_perchannel_grad_tbe
from .fake_quant_perlayer import _fake_quant_per_layer_tbe
from .fake_quant_perlayer_grad import _fake_quant_per_layer_grad_tbe
from .minmax_update_perchannel import _minmax_update_perchannel_tbe
from .minmax_update_perlayer import _minmax_update_perlayer_tbe
@@ -1,20 +1,18 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""
copyright 2020 Huawei Technologies Co., Ltd

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License == distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

_basic
"""
from __future__ import absolute_import
@@ -43,6 +43,25 @@ def _get_flattern_shape(shape):
    return (flattern_shape,)


def _error_feedback(input_shape):
    """error feedback"""
    support_shape = [((8, 128, 128), (8, 128, 128), "float32", False, True),
                     ((36, 128, 128), (36, 128, 128), "float32", False, True),
                     ((5, 128, 128), (5, 128, 128), "float32", False, True),
                     ((18, 128, 128), (18, 128, 128), "float32", False, True),
                     ((16, 128, 128), (16, 128, 128), "float32", False, True),
                     ((9, 128, 128), (9, 128, 128), "float32", False, True),
                     ((1, 64, 64), (1, 64, 64), "float32", False, True),
                     ((1, 128, 128), (1, 128, 128), "float32", False, True),
                     ((4, 128, 128), (4, 128, 128), "float32", False, True),
                     ((2, 128, 128), (2, 128, 128), "float32", False, True),
                     ((6, 128, 128), (6, 128, 128), "float32", False, True),
                     ((24, 128, 128), (24, 128, 128), "float32", False, True),
                     ((32, 128, 128), (32, 128, 128), 'float32', False, True)]
    if input_shape not in support_shape:
        raise RuntimeError("input_shape %s is not supported" % str(input_shape))


def _inner_matmul_new(tik_instance, dtype, input_info, res, res_index):
    """_inner_matmul_new"""
    input1, input1_index, input2, input2_index = input_info
@@ -100,9 +119,9 @@ def process_input_shape_640(input_shape, tik_instance, dtype, total_input, res):
    """process input shape of 640"""
    input1, input2 = total_input
    if input_shape == ((5, 128, 128), (5, 128, 128), "float32", False, True):
        with tik_instance.for_range(0, 30, block_num=30) as block_idx,\
                tik_instance.for_range(0, 11) as cc1_db,\
                tik_instance.for_range(0, 2, thread_num=2) as thread_idx,\
        with tik_instance.for_range(0, 30, block_num=30) as block_idx, \
                tik_instance.for_range(0, 11) as cc1_db, \
                tik_instance.for_range(0, 2, thread_num=2) as thread_idx, \
                tik_instance.if_scope(((((block_idx % 6) * 22) + (cc1_db * 2) + thread_idx) < 128)):
            input_1_local_ub = tik_instance.Tensor(dtype, [128], name="input_1_local_ub",
                                                   scope=tik.scope_ubuf)
@@ -157,17 +176,13 @@ def process_input_shape_1152(input_shape, tik_instance, dtype, total_input, res)
            input2_index = (block_idx // 3) * 16384
            res_index = (block_idx // 3) * 16384 + (block_idx % 3) * 5504 + cc0 * 128
            input_info = input1, input1_index, input2, input2_index
            _inner_matmul_new(tik_instance, dtype,
                              input_info,
                              res, res_index)
            _inner_matmul_new(tik_instance, dtype, input_info, res, res_index)
        with tik_instance.if_scope((block_idx % 3) < 2):
            input1_index = (block_idx // 3) * 16384 + (block_idx % 3) * 5504 + 42 * 128
            input2_index = (block_idx // 3) * 16384
            res_index = (block_idx // 3) * 16384 + (block_idx % 3) * 5504 + 42 * 128
            input_info = input1, input1_index, input2, input2_index
            _inner_matmul_new(tik_instance, dtype,
                              input_info,
                              res, res_index)
            _inner_matmul_new(tik_instance, dtype, input_info, res, res_index)


@op_info_register(cus_batchmatmul_op_info)
@@ -185,23 +200,9 @@ def cus_batch_matmul(input_x1, input_x2, output, transpose_a=False,
        raise RuntimeError("dtype of input_x1 and input_x2 must be same, but got %s vs %s" % (
            dtype, input_x2.get("dtype").lower()))
    input_shape = (tuple(x1_shape), tuple(x2_shape), dtype, transpose_a, transpose_b)
    support_shape = [((8, 128, 128), (8, 128, 128), "float32", False, True),
                     ((36, 128, 128), (36, 128, 128), "float32", False, True),
                     ((5, 128, 128), (5, 128, 128), "float32", False, True),
                     ((18, 128, 128), (18, 128, 128), "float32", False, True),
                     ((16, 128, 128), (16, 128, 128), "float32", False, True),
                     ((9, 128, 128), (9, 128, 128), "float32", False, True),
                     ((1, 64, 64), (1, 64, 64), "float32", False, True),
                     ((1, 128, 128), (1, 128, 128), "float32", False, True),
                     ((4, 128, 128), (4, 128, 128), "float32", False, True),
                     ((2, 128, 128), (2, 128, 128), "float32", False, True),
                     ((6, 128, 128), (6, 128, 128), "float32", False, True),
                     ((24, 128, 128), (24, 128, 128), "float32", False, True),
                     ((32, 128, 128), (32, 128, 128), 'float32', False, True)]
    if input_shape not in support_shape:
        raise RuntimeError("input_shape %s is not supported" % str(input_shape))

    # if not transpose_a and transpose_b:
    _error_feedback(input_shape)

    batch, m, k = x1_shape

    input1_shape = _get_flattern_shape(x1_shape)
@@ -215,14 +216,13 @@ def cus_batch_matmul(input_x1, input_x2, output, transpose_a=False,

    if input_shape == ((36, 128, 128), (36, 128, 128), "float32", False, True):
        with tik_instance.for_range(0, 18, block_num=18) as block_idx, \
                tik_instance.for_range(0, 2) as cc0,\
                tik_instance.for_range(0, 2) as cc0, \
                tik_instance.for_range(0, 128, thread_num=2) as cc1:
            input1_index = block_idx * 32768 + cc0 * 16384 + cc1 * 128
            input2_index = block_idx * 32768 + cc0 * 16384
            res_index = block_idx * 32768 + cc0 * 16384 + cc1 * 128
            input_info = input1, input1_index, input2, input2_index
            _inner_matmul_new(tik_instance, dtype,
                              input_info, res, res_index)
            _inner_matmul_new(tik_instance, dtype, input_info, res, res_index)

    total_input = input1, input2
    process_input_shape_640(input_shape, tik_instance, dtype, total_input, res)
@@ -234,20 +234,18 @@ def cus_batch_matmul(input_x1, input_x2, output, transpose_a=False,
            input2_index = block_idx * 16384
            res_index = block_idx * 16384 + cc0 * 128
            input_info = input1, input1_index, input2, input2_index
            _inner_matmul_new(tik_instance, dtype,
                              input_info, res, res_index)
            _inner_matmul_new(tik_instance, dtype, input_info, res, res_index)

    process_input_shape_1152(input_shape, tik_instance, dtype, total_input, res)

    if input_shape == ((1, 64, 64), (1, 64, 64), "float32", False, True):
        with tik_instance.for_range(0, 32, block_num=32) as block_idx,\
        with tik_instance.for_range(0, 32, block_num=32) as block_idx, \
                tik_instance.for_range(0, 2, thread_num=2) as cc0:
            input1_index = block_idx * 128 + cc0 * 64
            input2_index = 0
            res_index = block_idx * 128 + cc0 * 64
            input_info = input1, input1_index, input2, input2_index
            _inner_matmul_new_1_64_32_64(tik_instance, dtype,
                                         input_info,
            _inner_matmul_new_1_64_32_64(tik_instance, dtype, input_info,
                                         res, res_index)

    input_shape_list = [((1, 128, 128), (1, 128, 128), "float32", False, True),
@@ -257,14 +255,12 @@ def cus_batch_matmul(input_x1, input_x2, output, transpose_a=False,
                        ((8, 128, 128), (8, 128, 128), "float32", False, True),
                        ((16, 128, 128), (16, 128, 128), "float32", False, True),
                        ((24, 128, 128), (24, 128, 128), "float32", False, True),
                        ((32, 128, 128), (32, 128, 128), 'float32', False, True)
                        ]
                        ((32, 128, 128), (32, 128, 128), 'float32', False, True)]
    if input_shape in input_shape_list:
        block_num = 32
        block_num, thread_num = 32, 2
        input1_unit_size = 128
        input2_unint_size = 128 * 128
        block_process_ele_num = (batch * m * k) // block_num
        thread_num = 2
        loop_time = (batch * m * k) // block_num // input1_unit_size
        with tik_instance.for_range(0, block_num, block_num=block_num) as block_idx, \
                tik_instance.for_range(0, loop_time, thread_num=thread_num) as cc0:
@@ -28,11 +28,11 @@ batch_norm_op_info = TBERegOp("BatchNormFoldD") \
    .compute_cost(10) \
    .kernel_name("batchnorm_fold") \
    .partial_flag(True) \
    .attr("momentum", "optional", "float", "all", "0.9") \
    .attr("epsilon", "optional", "float", "all", "0.00001") \
    .attr("is_training", "optional", "bool", "all", "true") \
    .attr("freeze_bn", "optional", "int", "all", "0") \
    .attr("format", "optional", "str", "all", "NCHW") \
    .attr("momentum", "optional", "float", "all") \
    .attr("epsilon", "optional", "float", "all") \
    .attr("is_training", "optional", "bool", "all") \
    .attr("freeze_bn", "optional", "int", "all") \
    .attr("format", "optional", "str", "all") \
    .input(0, "x", False, "required", "all") \
    .input(1, "x_sum", False, "required", "all") \
    .input(2, "x_square_sum", False, "required", "all") \
@@ -57,6 +57,43 @@ def _batchnorm_fold_tbe():
    return


def _batchnorm_fold_compute(x_input, x_sum, x_square_sum, mean, variance, momentum, epsilon):
    """_batchnorm_fold_compute"""
    shape_x = te.lang.cce.util.shape_to_list(x_input.shape)
    num = shape_x[0] * shape_x[2] * shape_x[3]
    num_rec = 1.0 / num

    # compute the mean of x
    batch_mean = te.lang.cce.vmuls(x_sum, num_rec)

    # compute the variance of x
    variance_div = te.lang.cce.vmuls(x_square_sum, num_rec)
    mean_square = te.lang.cce.vmul(batch_mean, batch_mean)
    batch_var_biased = te.lang.cce.vsub(variance_div, mean_square)
    batch_std = te.lang.cce.vsqrt(te.lang.cce.vadds(batch_var_biased, epsilon))
    if num == 1:
        batch_var_scaler = 0.0
    else:
        batch_var_scaler = float(num) / (num - 1)
    batch_var_unbiased = te.lang.cce.vmuls(batch_var_biased, batch_var_scaler)

    factor = 1.0 - momentum
    factor_reverse = momentum
    mean_mul = te.lang.cce.vmuls(batch_mean, factor)
    mean_mul_rev = te.lang.cce.vmuls(mean, factor_reverse)
    mean_updated = te.lang.cce.vadd(mean_mul, mean_mul_rev)

    var_mul = te.lang.cce.vmuls(batch_var_unbiased, factor)
    var_mul_rev = te.lang.cce.vmuls(variance, factor_reverse)
    variance_updated = te.lang.cce.vadd(var_mul, var_mul_rev)

    y = te.lang.cce.vadds(x_input, 0.0)
    running_mean = te.lang.cce.vadds(mean, 0.0)
    running_std = te.lang.cce.vsqrt(te.lang.cce.vadds(variance, epsilon))
    res = [y, batch_mean, batch_std, running_mean, running_std, mean_updated, variance_updated]
    return res


@util.check_input_type(dict, dict, dict, dict, dict,
                       dict, dict, dict, dict, dict, dict, dict,
                       float, float, bool, int, str, str)
@@ -108,39 +145,7 @@ def batchnorm_fold(x, x_sum, x_square_sum, mean, variance,
    mean = tvm.placeholder(shape_mean, name="mean", dtype=dtype_mean.lower())
    variance = tvm.placeholder(shape_mean, name="variance", dtype=dtype_variance.lower())

    shape_x = te.lang.cce.util.shape_to_list(x_input.shape)
    num = shape_x[0] * shape_x[2] * shape_x[3]
    num_rec = 1.0 / num

    # compute the mean of x
    batch_mean = te.lang.cce.vmuls(x_sum, num_rec)

    # compute the variance of x
    variance_div = te.lang.cce.vmuls(x_square_sum, num_rec)
    mean_square = te.lang.cce.vmul(batch_mean, batch_mean)
    batch_var_biased = te.lang.cce.vsub(variance_div, mean_square)
    batch_std = te.lang.cce.vsqrt(te.lang.cce.vadds(batch_var_biased, epsilon))
    if num == 1:
        batch_var_scaler = 0.0
    else:
        batch_var_scaler = float(num) / (num - 1)
    batch_var_unbiased = te.lang.cce.vmuls(batch_var_biased, batch_var_scaler)

    factor = 1.0 - momentum
    factor_reverse = momentum
    mean_mul = te.lang.cce.vmuls(batch_mean, factor)
    mean_mul_rev = te.lang.cce.vmuls(mean, factor_reverse)
    mean_updated = te.lang.cce.vadd(mean_mul, mean_mul_rev)

    var_mul = te.lang.cce.vmuls(batch_var_unbiased, factor)
    var_mul_rev = te.lang.cce.vmuls(variance, factor_reverse)
    variance_updated = te.lang.cce.vadd(var_mul, var_mul_rev)

    y = te.lang.cce.vadds(x_input, 0.0)
    running_mean = te.lang.cce.vadds(mean, 0.0)
    running_std = te.lang.cce.vsqrt(te.lang.cce.vadds(variance, epsilon))
    res = [y, batch_mean, batch_std, running_mean, running_std, mean_updated, variance_updated]

    res = _batchnorm_fold_compute(x_input, x_sum, x_square_sum, mean, variance, momentum, epsilon)
    with tvm.target.cce():
        sch = generic.auto_schedule(res)
    config = {"name": kernel_name,
@@ -21,6 +21,7 @@ from te.platform.cce_build import build_config
from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
from impl.bn_training_reduce import bn_training_reduce_schedule_nd

SHAPE_SIZE_LIMIT = 2147483648
@@ -81,9 +82,8 @@ def batchnorm_fold2_grad_reduce(dout, x, dout_reduce, dout_x_reduce, kernel_name
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape)
    util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
    check_list = ["float16", "float32"]
    inp_dtype = x.get("dtype").lower()
    if not inp_dtype in check_list:
    if not inp_dtype in ["float16", "float32"]:
        raise RuntimeError("Dtype of input only support float16, float32")
    dout_t = tvm.placeholder(shape, name="dout", dtype=inp_dtype)
    x_t = tvm.placeholder(shape, name="x", dtype=inp_dtype)
@@ -100,7 +100,7 @@ def batchnorm_fold2_grad_reduce(dout, x, dout_reduce, dout_x_reduce, kernel_name

        te.lang.cce.cce_build_code(sch, config)
        return
    from impl.bn_training_reduce import bn_training_reduce_schedule_nd

    sch, tensor_list = bn_training_reduce_schedule_nd(res_list)
    with build_config:
        tvm.build(sch, tensor_list, "cce", name=kernel_name)
@@ -95,12 +95,12 @@ def batchnorm_fold_grad(d_batch_mean, d_batch_std, x, batch_mean, batch_std, dx,
    dtype_mean = d_batch_mean.get("dtype").lower()
    if format_data == "NC1HWC0":
        if len(shape_x) != 5:
            raise RuntimeError("batchnorm_fold only support shape 5D"
            raise RuntimeError("batchnorm_fold grad only support shape 5D"
                               "when input format is NC1HWC0")
        shape_mean = (1, shape_x[1], 1, 1, shape_x[4])
    elif format_data == "NCHW":
        if len(shape_x) < 2 or len(shape_x) > 4:
            raise RuntimeError("batchnorm_fold only support shape 2D to 4D")
            raise RuntimeError("batchnorm_fold grad only support shape 2D to 4D")
        if shape_x[1] != shape_mean[0]:
            raise RuntimeError("data_format is NCHW, shape_bias must"
                               "be equal to the second axis of shape_x")
@@ -66,9 +66,8 @@ def correction_mul(x, batch_std, running_std, y, channel, kernel_name="correctio
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape)
    util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
    check_list = ["float16", "float32"]
    inp_dtype = x.get("dtype").lower()
    if not inp_dtype in check_list:
    if not inp_dtype in ["float16", "float32"]:
        raise RuntimeError("Dtype of input only support float16, float32")

    x_t = tvm.placeholder(shape, name="x", dtype=inp_dtype)
@@ -86,11 +86,9 @@ def fake_quant_perchannel_compute(x, min_val, max_val, y, quant_min, quant_max,
    return res


@util.check_input_type(dict, dict, dict, dict, bool, bool, int, int, str)
def fake_quant_perchannel(x, min_val, max_val, y,
                          symmetric, narrow_range, num_bits, channel_axis,
                          kernel_name="fake_quant_perchannel"):
    """FakeQuantPerChannel"""
def fake_quant_perchannel_param(x, min_val, max_val, channel_axis,
                                kernel_name="fake_quant_perchannel"):
    """Get and check fake_quant_perchannel parameters"""
    x_shape = x.get("shape")
    x_shape_ = x.get("ori_shape")
    x_format = x.get("format")
@@ -120,15 +118,25 @@ def fake_quant_perchannel(x, min_val, max_val, y,
    util.check_dtype_rule(min_dtype, check_list)
    util.check_dtype_rule(max_dtype, check_list)

    shape_c = [1] * len(x_shape)
    shape_c[channel_axis_] = min_val.get("ori_shape")[0]
    if x_format == "NC1HWC0" and channel_axis_ == 1:
        shape_c = min_val.get("shape")
    return x_shape, shape_c, x_dtype


@util.check_input_type(dict, dict, dict, dict, bool, bool, int, int, str)
def fake_quant_perchannel(x, min_val, max_val, y,
                          symmetric, narrow_range, num_bits, channel_axis,
                          kernel_name="fake_quant_perchannel"):
    """FakeQuantPerChannel"""
    quant_min = 0
    quant_max = 2 ** num_bits - 1
    if narrow_range:
        quant_min = quant_min + 1

    shape_c = [1] * len(x_shape)
    shape_c[channel_axis_] = min_val.get("ori_shape")[0]
    if x_format == "NC1HWC0" and channel_axis_ == 1:
        shape_c = min_val.get("shape")
    x_shape, shape_c, x_dtype = fake_quant_perchannel_param(x, min_val, max_val,
                                                            channel_axis, kernel_name)
    input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
    min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
    max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
@@ -110,11 +110,9 @@ def fake_quant_perchannel_grad_compute(dout, x, min_val, max_val, quant_min, qua
    return res


@util.check_input_type(dict, dict, dict, dict, dict, bool, bool, int, int, str)
def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
                               symmetric, narrow_range, num_bits, channel_axis,
                               kernel_name="fake_quant_perchannel_grad"):
    """FakeQuantPerChannelGrad"""
def fake_quant_perchannel_grad_param(x, min_val, max_val, channel_axis,
                                     kernel_name="fake_quant_perchannel_grad"):
    """Get and check FakeQuantPerChannelGrad parameters"""
    x_shape = x.get("shape")
    x_shape_ = x.get("ori_shape")
    x_format = x.get("format")
@@ -144,6 +142,18 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
    util.check_dtype_rule(min_dtype, check_list)
    util.check_dtype_rule(max_dtype, check_list)

    shape_c = [1] * len(x_shape)
    shape_c[channel_axis_] = min_val.get("ori_shape")[0]
    if x_format == "NC1HWC0" and channel_axis_ == 1:
        shape_c = min_val.get("shape")
    return x_shape, shape_c, x_dtype


@util.check_input_type(dict, dict, dict, dict, dict, bool, bool, int, int, str)
def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
                               symmetric, narrow_range, num_bits, channel_axis,
                               kernel_name="fake_quant_perchannel_grad"):
    """FakeQuantPerChannelGrad"""
    if symmetric:
        quant_min = 0 - 2 ** (num_bits - 1)
        quant_max = 2 ** (num_bits - 1) - 1
@@ -153,10 +163,8 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
    if narrow_range:
        quant_min = quant_min + 1

    shape_c = [1] * len(x_shape)
    shape_c[channel_axis_] = min_val.get("ori_shape")[0]
    if x_format == "NC1HWC0" and channel_axis_ == 1:
        shape_c = min_val.get("shape")
    x_shape, shape_c, x_dtype = fake_quant_perchannel_grad_param(x, min_val, max_val,
                                                                 channel_axis, kernel_name)
    dout_data = tvm.placeholder(x_shape, name="dout", dtype=x_dtype)
    input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
    min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
@@ -54,6 +54,25 @@ def _update_tik(tik_instance, input_x_ub, broadcast_0_local_ub, block_index, res
    return tik_instance, res


def _error_feedback(input_info):
    """error feedback"""
    support_shape = [((1, 128, 128), "float32"),
                     ((2, 128, 128), "float32"),
                     ((4, 128, 128), "float32"),
                     ((8, 128, 128), "float32"),
                     ((16, 128, 128), "float32"),
                     ((5, 128, 128), "float32"),
                     ((9, 128, 128), "float32"),
                     ((18, 128, 128), "float32"),
                     ((36, 128, 128), "float32"),
                     ((32, 128, 128), "float32"),
                     ((1, 64, 64), "float32"),
                     ((32, 64), "float32")
                     ]
    if input_info not in support_shape:
        raise RuntimeError("input_shape %s is not supported" % str(input_info))


def shape0(tik_instance, input_x_shape, input_x, res):
    """shape0"""
    total_elements0 = 1
@@ -482,25 +501,13 @@ def cus_fused_abs_max1(input_x, output, origin_shape=None, kernel_name="cus_fuse

    tik_instance = _get_tik_instance()

    support_shape = [((1, 128, 128), "float32"),
                     ((2, 128, 128), "float32"),
                     ((4, 128, 128), "float32"),
                     ((8, 128, 128), "float32"),
                     ((16, 128, 128), "float32"),
                     ((5, 128, 128), "float32"),
                     ((9, 128, 128), "float32"),
                     ((18, 128, 128), "float32"),
                     ((36, 128, 128), "float32"),
                     ((32, 128, 128), "float32"),
                     ((1, 64, 64), "float32"),
                     ((32, 64), "float32")
                     ]
    ori_shape = tuple(origin_shape)
    input_info = (tuple(input_x_shape), dtype)
    input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm)
    res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm)
    if input_info not in support_shape:
        raise RuntimeError("input_shape %s is not supported" % str(input_info))

    _error_feedback(input_info)

    if input_info == ((1, 128, 128), "float32"):
        tik_instance, res = shape0(tik_instance, input_x_shape, input_x, res)
    elif input_info == ((2, 128, 128), "float32"):
@@ -534,8 +541,6 @@ def cus_fused_abs_max1(input_x, output, origin_shape=None, kernel_name="cus_fuse
        tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8)
        tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8)
        tik_instance.data_move(res[0], input_x_ub, 0, 1, 1, 0, 0)
    else:
        raise RuntimeError("UnSupportedShape")

    tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x], outputs=[res])
    return tik_instance
@@ -35,7 +35,7 @@ cus_img2col_info = TBERegOp("CusImg2Col") \


def shape56_0(tik_instance, input_x, res, input_shape, shape_info):
    """input_shape == ((32, 4, 56, 56, 16), 'float16', (3, 3), (1, 1))"""
    """input_shape is ((32, 4, 56, 56, 16), 'float16', (3, 3), (1, 1))"""
    stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
    pad = [1, 1, 1, 1]
    l1_h, l1_w, jump_stride, repeat_mode = 56, 56, 1, 1
@@ -62,7 +62,7 @@ def shape56_0(tik_instance, input_x, res, input_shape, shape_info):


def shape56_1(tik_instance, input_x, res, input_shape, shape_info):
    """input_shape == ((32, 8, 56, 56, 16), 'float16', (3, 3), (2, 2))"""
    """input_shape is ((32, 8, 56, 56, 16), 'float16', (3, 3), (2, 2))"""
    stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
    pad = [1, 1, 1, 1]
    l1_h, l1_w, jump_stride, repeat_mode = 56, 56, 1, 1
@@ -90,7 +90,7 @@ def shape56_1(tik_instance, input_x, res, input_shape, shape_info):


def shape56_2(tik_instance, input_x, res, input_shape, shape_info):
    """input_shape == ((32, 4, 56, 56, 16), 'float16', (1, 1), (1, 1))"""
    """input_shape is ((32, 4, 56, 56, 16), 'float16', (1, 1), (1, 1))"""
    stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
    pad = [0, 0, 0, 0]
    l1_h, l1_w, c1_index, jump_stride, repeat_mode = 56, 56, 0, 1, 1
@@ -114,7 +114,7 @@ def shape56_2(tik_instance, input_x, res, input_shape, shape_info):


def shape56_3(tik_instance, input_x, res, input_shape, shape_info):
    """input_shape == ((32, 16, 56, 56, 16), 'float16', (1, 1), (1, 1))"""
    """input_shape is ((32, 16, 56, 56, 16), 'float16', (1, 1), (1, 1))"""
    stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
    pad = [0, 0, 0, 0]
    l1_h, l1_w, c1_index, jump_stride, repeat_mode = 56, 56, 0, 1, 1
@@ -142,7 +142,7 @@ def shape56_3(tik_instance, input_x, res, input_shape, shape_info):


def shape56_4(tik_instance, input_x, res, input_shape, shape_info):
    """input_shape == ((32, 16, 56, 56, 16), 'float16', (1, 1), (2, 2))"""
    """input_shape is ((32, 16, 56, 56, 16), 'float16', (1, 1), (2, 2))"""
    stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
    pad = [0, 0, 0, 0]
    l1_h, l1_w, c1_index, jump_stride, repeat_mode = 56, 56, 0, 1, 1
@@ -171,7 +171,7 @@ def shape56_4(tik_instance, input_x, res, input_shape, shape_info):


def shape28_0(tik_instance, input_x, res, input_shape, shape_info):
    """input_shape == ((32, 8, 28, 28, 16), 'float16', (3, 3), (1, 1))"""
    """input_shape is ((32, 8, 28, 28, 16), 'float16', (3, 3), (1, 1))"""
    stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
    pad = [1, 1, 1, 1]
    l1_h, l1_w, jump_stride, repeat_mode = 28, 28, 1, 1
@@ -199,7 +199,7 @@ def shape28_0(tik_instance, input_x, res, input_shape, shape_info):


def shape28_1(tik_instance, input_x, res, input_shape, shape_info):
    """input_shape == ((32, 16, 28, 28, 16), 'float16', (3, 3), (2, 2))"""
    """input_shape is ((32, 16, 28, 28, 16), 'float16', (3, 3), (2, 2))"""
    stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
    pad = [1, 1, 1, 1]
    l1_h, l1_w, jump_stride, repeat_mode = 28, 28, 1, 1
@@ -238,7 +238,7 @@ def shape28_1(tik_instance, input_x, res, input_shape, shape_info):


def shape28_2(tik_instance, input_x, res, input_shape, shape_info):
    """input_shape == ((32, 32, 28, 28, 16), 'float16', (1, 1), (2, 2))"""
    """input_shape is ((32, 32, 28, 28, 16), 'float16', (1, 1), (2, 2))"""
    stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
    pad = [0, 0, 0, 0]
    l1_h, l1_w, jump_stride, repeat_mode = 28, 28, 1, 1
@@ -286,7 +286,7 @@ def shape28_2(tik_instance, input_x, res, input_shape, shape_info):


def shape28_3(tik_instance, input_x, res, input_shape, shape_info):
    """input_shape == ((32, 8, 28, 28, 16), 'float16', (1, 1), (1, 1))"""
    """input_shape is ((32, 8, 28, 28, 16), 'float16', (1, 1), (1, 1))"""
    stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
    pad = [0, 0, 0, 0]
    l1_h, l1_w, jump_stride, repeat_mode = 28, 28, 1, 1
@@ -314,7 +314,7 @@ def shape28_3(tik_instance, input_x, res, input_shape, shape_info):


def shape28_4(tik_instance, input_x, res, input_shape, shape_info):
    """input_shape == ((32, 32, 28, 28, 16), 'float16', (1, 1), (1, 1))"""
    """input_shape is ((32, 32, 28, 28, 16), 'float16', (1, 1), (1, 1))"""
    stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
    pad = [0, 0, 0, 0]
    l1_h, l1_w, jump_stride, repeat_mode = 28, 28, 1, 1
@@ -342,7 +342,7 @@ def shape28_4(tik_instance, input_x, res, input_shape, shape_info):


def shape14_0(tik_instance, input_x, res, input_shape, shape_info):
    """input_shape == ((32, 16, 14, 14, 16), 'float16', (3, 3), (1, 1))"""
    """input_shape is ((32, 16, 14, 14, 16), 'float16', (3, 3), (1, 1))"""
    stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
    pad = [1, 1, 1, 1]
    l1_h, l1_w, jump_stride, repeat_mode = 14, 14, 1, 1
@@ -382,7 +382,7 @@ def shape14_0(tik_instance, input_x, res, input_shape, shape_info):


def shape14_1(tik_instance, input_x, res, input_shape, shape_info):
    """input_shape == ((32, 32, 14, 14, 16), 'float16', (3, 3), (2, 2))"""
    """input_shape is ((32, 32, 14, 14, 16), 'float16', (3, 3), (2, 2))"""
    stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
    pad = [1, 1, 1, 1]
    l1_h, l1_w, jump_stride, repeat_mode = 14, 14, 1, 1
@@ -419,7 +419,7 @@ def shape14_1(tik_instance, input_x, res, input_shape, shape_info):


def shape14_2(tik_instance, input_x, res, input_shape, shape_info):
    """input_shape == ((32, 64, 14, 14, 16), 'float16', (1, 1), (2, 2))"""
    """input_shape is ((32, 64, 14, 14, 16), 'float16', (1, 1), (2, 2))"""
    stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
    pad = [0, 0, 0, 0]
    l1_h, l1_w, jump_stride, repeat_mode = 14, 14, 1, 1
@@ -454,7 +454,7 @@ def shape14_2(tik_instance, input_x, res, input_shape, shape_info):


def shape14_3(tik_instance, input_x, res, input_shape, shape_info):
    """input_shape == ((32, 64, 14, 14, 16), 'float16', (1, 1), (1, 1))"""
    """input_shape is ((32, 64, 14, 14, 16), 'float16', (1, 1), (1, 1))"""
    stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
    pad = [0, 0, 0, 0]
    l1_h, l1_w, jump_stride, repeat_mode = 14, 14, 1, 1
@@ -476,7 +476,7 @@ def shape14_3(tik_instance, input_x, res, input_shape, shape_info):
        with tik_instance.for_range(eeb1 * 16, (eeb1 + 1) * 16) as i:
            rep = 13
            c1_index = 0
            fetch_filter_w, fetch_filter_h, left_top_w, left_top_h = 0, 0, 0, 0
            fetch_filter_w, fetch_filter_h, left_top_w, left_top_h, rep, c1_index = 0, 0, 0, 0, 13, 0
            tik_instance.load3dv1(input_1_1_fractal_l1_local_ub[3328 * (i - eeb1 * 16)],
                                  input_1_1_local_l1[3136 * i],
                                  pad, l1_h, l1_w, c1_index, fetch_filter_w, fetch_filter_h,
@@ -492,9 +492,7 @@ def shape14_3(tik_instance, input_x, res, input_shape, shape_info):

    with tik_instance.for_range(0, 2) as eeb1:
        with tik_instance.for_range(eeb1 * 16, (eeb1 + 1) * 16) as i:
            rep = 13
            fetch_filter_w, fetch_filter_h, left_top_w, left_top_h = 0, 0, 0, 0
            c1_index = 0
            fetch_filter_w, fetch_filter_h, left_top_w, left_top_h, rep, c1_index = 0, 0, 0, 0, 13, 0
            tik_instance.load3dv1(input_1_1_fractal_l1_local_ub[3328 * (i - eeb1 * 16)],
                                  input_1_2_local_l1[3136 * i],
                                  pad, l1_h, l1_w, c1_index, fetch_filter_w, fetch_filter_h,
@@ -511,7 +509,7 @@ def shape14_3(tik_instance, input_x, res, input_shape, shape_info):


def shape14_4(tik_instance, input_x, res, input_shape, shape_info):
    """input_shape == ((32, 16, 14, 14, 16), 'float16', (1, 1), (1, 1))"""
    """input_shape is ((32, 16, 14, 14, 16), 'float16', (1, 1), (1, 1))"""
    stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
    pad = [0, 0, 0, 0]
    l1_h = 14
@@ -551,7 +549,7 @@ def shape14_4(tik_instance, input_x, res, input_shape, shape_info):


def shape7_0(tik_instance, input_x, res, input_shape, shape_info):
    """input_shape == ((32, 32, 7, 7, 16), 'float16', (3, 3), (1, 1))"""
    """input_shape is ((32, 32, 7, 7, 16), 'float16', (3, 3), (1, 1))"""
    stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
    pad = [1, 1, 1, 1]
    l1_h = 7
@@ -576,36 +574,34 @@ def shape7_0(tik_instance, input_x, res, input_shape, shape_info):
        left_top_h = -1
        c1_index = 0
        with tik_instance.for_range(0, 32) as i:
            tik_instance.load3dv1(input_1_1_fractal_l1_local_ub[1024 * i], input_1_1_local_l1[784 * i],
                                  pad, l1_h, l1_w, c1_index, fetch_filter_w, fetch_filter_h,
                                  left_top_w, left_top_h, stride_w, stride_h, filter_w,
                                  filter_h, dilation_filter_w, dilation_filter_h,
                                  jump_stride, repeat_mode, rep)
            tik_instance.load3dv1(input_1_1_fractal_l1_local_ub[1024 * i],
                                  input_1_1_local_l1[784 * i], pad, l1_h, l1_w, c1_index,
                                  fetch_filter_w, fetch_filter_h, left_top_w, left_top_h,
                                  stride_w, stride_h, filter_w, filter_h, dilation_filter_w,
                                  dilation_filter_h, jump_stride, repeat_mode, rep)
        with tik_instance.for_range(0, 32) as i:
            tik_instance.data_move(input_1_2_fractal_l1_local_ub[i * 49 * 16],
                                   input_1_1_fractal_l1_local_ub[i * 1024], 0, 1, 49, 0, 0)

        with tik_instance.for_range(0, 98) as i:
            tik_instance.data_move(res[eeb + block_index * 9, i, 0, 0], input_1_2_fractal_l1_local_ub[256 * i],
                                   0, 1, 16, 0, 0)
            tik_instance.data_move(res[eeb + block_index * 9, i, 0, 0],
                                   input_1_2_fractal_l1_local_ub[256 * i], 0, 1, 16, 0, 0)
    return tik_instance, res


def shape7_1(tik_instance, input_x, res, input_shape, shape_info):
    """input_shape == ((32, 128, 7, 7, 16), 'float16', (1, 1), (1, 1))"""
    """input_shape is ((32, 128, 7, 7, 16), 'float16', (1, 1), (1, 1))"""
    stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
    pad = [0, 0, 0, 0]
    l1_h = 7
    l1_w = 7
    jump_stride = 1
    repeat_mode = 1
    l1_h, l1_w, jump_stride, repeat_mode = 7, 7, 1, 1
    with tik_instance.for_range(0, 32, block_num=32) as block_index:
        input_1_2_fractal_l1_local_ub = tik_instance.Tensor("float16", (25088,), scope=tik.scope_ubuf,
                                                            name="input_1_2_fractal_l1_local_ub")
        input_1_1_local_l1 = tik_instance.Tensor("float16", (25088,), scope=tik.scope_cbuf,
                                                 name="input_1_1_local_l1")
        input_1_1_fractal_l1_local_ub = tik_instance.Tensor("float16", (32768,), scope=tik.scope_ubuf,
                                                            name="input_1_1_fractal_l1_local_ub")
        input_1_2_fractal_l1_local_ub = tik_instance.Tensor("float16", (25088,), scope=tik.scope_ubuf,
                                                            name="input_1_2_fractal_l1_local_ub")

        with tik_instance.for_range(0, 4) as eeb0, tik_instance.for_range(0, 32) as i:
            tik_instance.data_move(input_1_1_local_l1[i * 784], input_x[i, eeb0 + block_index * 4, 0, 0, 0], 0,
                                   1, 49, 0, 0)
@@ -633,21 +629,21 @@ def shape7_1(tik_instance, input_x, res, input_shape, shape_info):


def shape7_2(tik_instance, input_x, res, input_shape, shape_info):
    """input_shape == ((32, 32, 7, 7, 16), 'float16', (1, 1), (1, 1))"""
    """input_shape is ((32, 32, 7, 7, 16), 'float16', (1, 1), (1, 1))"""
    stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
    pad = [0, 0, 0, 0]
    l1_h = 7
    l1_w = 7
    c1_index = 0
    jump_stride = 1
    repeat_mode = 1
    jump_stride, repeat_mode = 1, 1
    with tik_instance.for_range(0, 32, block_num=32) as block_index:
        input_1_2_fractal_l1_local_ub = tik_instance.Tensor("float16", (25088,), scope=tik.scope_ubuf,
                                                            name="input_1_2_fractal_l1_local_ub")
        input_1_1_local_l1 = tik_instance.Tensor("float16", (25088,), scope=tik.scope_cbuf,
                                                 name="input_1_1_local_l1")
        input_1_1_fractal_l1_local_ub = tik_instance.Tensor("float16", (32768,), scope=tik.scope_ubuf,
                                                            name="input_1_1_fractal_l1_local_ub")
        input_1_2_fractal_l1_local_ub = tik_instance.Tensor("float16", (25088,), scope=tik.scope_ubuf,
                                                            name="input_1_2_fractal_l1_local_ub")


        with tik_instance.for_range(0, 32) as i:
            tik_instance.data_move(input_1_1_local_l1[i * 784], input_x[i, block_index, 0, 0, 0], 0, 1, 49, 0, 0)
@@ -673,14 +669,10 @@ def shape7_2(tik_instance, input_x, res, input_shape, shape_info):


def height224_width224(tik_instance, input_x, res, input_shape, shape_info):
    """input_shape = ((32, 1, 224, 224, 16), 'float16', (7, 7), (2, 2))"""
    """input_shape is ((32, 1, 224, 224, 16), 'float16', (7, 7), (2, 2))"""
    stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
    pad = [3, 3, 3, 3]
    l1_h = 56
    l1_w = 224
    c1_index = 0
    jump_stride = 1
    repeat_mode = 1
    l1_h, l1_w, c1_index, jump_stride, repeat_mode = 56, 224, 0, 1, 1
    with tik_instance.for_range(0, 32, block_num=32) as block_index:
        input_1_1_local_l1 = tik_instance.Tensor("float16", (200704,), scope=tik.scope_cbuf,
                                                 name="input_1_1_local_l1")
@@ -717,11 +709,10 @@ def height224_width224(tik_instance, input_x, res, input_shape, shape_info):

                left_top_h = 1 + ((55 - temp - (-3 + eeb)) // 2 - 29) * 2

                tik_instance.load3dv1(input_1_1_fractal_l1_local_ub, input_1_1_local_l1, pad,
                                      l1_h, l1_w, c1_index, fetch_filter_w, fetch_filter_h,
                                      left_top_w, left_top_h, stride_w, stride_h, filter_w,
                                      filter_h, dilation_filter_w, dilation_filter_h, jump_stride,
                                      repeat_mode, rep)
                tik_instance.load3dv1(input_1_1_fractal_l1_local_ub, input_1_1_local_l1, pad, l1_h, l1_w,
                                      c1_index, fetch_filter_w, fetch_filter_h, left_top_w, left_top_h,
                                      stride_w, stride_h, filter_w, filter_h, dilation_filter_w,
                                      dilation_filter_h, jump_stride, repeat_mode, rep)
                with tik_instance.for_range(0, rep) as cc1:
                    tik_instance.data_move(
                        res[cc0 + eeb * 7, cc1 + rep_prefix + (eeb0 - 1) * rep + 784 * block_index, 0, 0],
@@ -46,6 +46,7 @@ matmul_cube_dense_left_op_info = TBERegOp("CusMatMulCubeDenseLeft") \
    .dtype_format(DataType.F16_Default, DataType.F16_FracNZ, DataType.F16_Default, DataType.F16_FracNZ) \
    .get_op_info()


def shape_gen1(input_x1, input_x2, output_y, kernel_name, trans_a, trans_b):
    """shape gen1"""
    shape_a = input_x1.get("ori_shape")
@@ -94,6 +95,7 @@ def shape_gen1(input_x1, input_x2, output_y, kernel_name, trans_a, trans_b):
        trans_b = bool(1 - trans_b)
    return shape_a, shape_b, trans_a, trans_b, shape_output


def shape_gen2(bias, input_x1, output_y, shape_a, shape_b, trans_a, trans_b):
    """shape gen2"""
    shape_bias = ()
@@ -144,6 +146,7 @@ def shape_gen2(bias, input_x1, output_y, shape_a, shape_b, trans_a, trans_b):
        format_b = "FRACTAL_NZ"
    return shape_a_temp, format_a, shape_b_temp, format_b, shape_bias, src_dtype, dst_dtype


def core(shape_a_temp, shape_b_temp, shape_output, kernel_name):
    """core func"""
    if util.get_product_version() == util.VERSION_MINI:
@@ -206,6 +209,7 @@ def core(shape_a_temp, shape_b_temp, shape_output, kernel_name):
    tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x1, input_x2], outputs=[resmatmul])
    return tik_instance


@op_info_register(matmul_cube_dense_left_op_info)
def cus_matmul_cube_dense_left(input_x1, input_x2, bias=None, output_y=None, trans_a=False, trans_b=False,
                               kernel_name="cus_matmul_cube_dense_left"):
@@ -68,106 +68,133 @@ def cus_matmul_cube_dense_right(input_x1, input_x2, input_x3, output_y=None,
        core_m_idx = block_index // 16
        core_n_idx = block_index % 16
        matrix_max_scalar = tik_instance.Scalar("float32")
        matrix_max_local_UB = tik_instance.Tensor("float32", (8,), scope=tik.scope_ubuf,
                                                  name="matrix_max_local_UB")
        tik_instance.data_move(matrix_max_local_UB, input_x3, 0, 1, 1, 0, 0)
        matrix_max_scalar.set_as(matrix_max_local_UB[0])

        resmatmul_local_ub = tik_instance.Tensor("float32", (256 * 128,), scope=tik.scope_ubuf,
                                                 name="resmatmul_local_ub")
        resmatmul_local_ub1 = tik_instance.Tensor("float32", (240 * 128,), scope=tik.scope_ubuf,
                                                  name="resmatmul_local_ub1")

        resmatmul_local_ub_local_l0c = tik_instance.Tensor("float32", (256 * 128,), scope=tik.scope_cc,
                                                           name="resmatmul_local_ub_local_l0c")
        resmatmul_local_ub_local_l0c1 = tik_instance.Tensor("float32", (240 * 128,), scope=tik.scope_cc,
                                                            name="resmatmul_local_ub_local_l0c1")

        input_1_local_l1_local_L0A = tik_instance.Tensor("float16", (256 * 128,), scope=tik.scope_ca,
                                                         name="input_1_local_l1_local_L0A")
        input_2_local_l1 = tik_instance.Tensor("float16", (8 * 128 * 16,), scope=tik.scope_cbuf,
                                               name="input_2_local_l1")
        input_2_local_l11 = tik_instance.Tensor("float16", (8 * 128 * 16,), scope=tik.scope_cbuf,
                                                name="input_2_local_l11")

        input_1_local_l1 = tik_instance.Tensor("float16", (8 * 256 * 16,), scope=tik.scope_cbuf,
                                               name="input_1_local_l1")
        input_1_local_l11 = tik_instance.Tensor("float16", (8 * 240 * 16,), scope=tik.scope_cbuf,
                                                name="input_1_local_l11")

        input_2_local_l1_local_l0b = tik_instance.Tensor("float16", (128 * 128,), scope=tik.scope_cb,
                                                         name="input_2_local_l1_local_l0b")
        input_2_local_l1_local_l0b1 = tik_instance.Tensor("float16", (128 * 128,), scope=tik.scope_cb,
                                                          name="input_2_local_l1_local_l0b1")
        matrix_max_local_ub = tik_instance.Tensor("float32", (8,), scope=tik.scope_ubuf,
                                                  name="matrix_max_local_ub")
        tik_instance.data_move(matrix_max_local_ub, input_x3, 0, 1, 1, 0, 0)
        matrix_max_scalar.set_as(matrix_max_local_ub[0])

        with tik_instance.if_scope(core_m_idx == 0):
            with tik_instance.for_range(0, 2) as cc1:
                tik_instance.data_move(input_2_local_l1, input_x2[core_n_idx * 262144 + core_n_idx * 2048], 0, 8,
                                       128, 1920, 0)
                tik_instance.data_move(input_1_local_l1, input_x1[core_n_idx * 129024 + cc1 * 4096], 0, 8, 256,
                                       752, 0)
                with tik_instance.for_range(0, 8) as cc10:
                    tik_instance.load2dv1(input_2_local_l1_local_l0b[cc10 * 2048], input_2_local_l1[cc10 * 256], 0,
                                          8, 8, 0, True)
                with tik_instance.for_range(0, 16) as cc101:
                    tik_instance.load2dv1(input_1_local_l1_local_L0A[cc101 * 2048], input_1_local_l1[cc101 * 256],
                                          0, 8, 16, 0, False)

                tik_instance.mmad(resmatmul_local_ub_local_l0c, input_1_local_l1_local_L0A,
                                  input_2_local_l1_local_l0b, 256, 128, 128, 0)
                tik_instance.data_move(resmatmul_local_ub, resmatmul_local_ub_local_l0c, 0, 1, 128, 0, 0)
                tik_instance.vmuls(64, resmatmul_local_ub, resmatmul_local_ub, matrix_max_scalar, 255, 1, 1, 8, 8)
                tik_instance.vmuls(64, resmatmul_local_ub[255 * 64], resmatmul_local_ub[255 * 64],
                                   matrix_max_scalar, 255, 1, 1, 8, 8)
                tik_instance.vmuls(64, resmatmul_local_ub[510 * 64], resmatmul_local_ub[510 * 64],
                                   matrix_max_scalar, 2, 1, 1, 8, 8)

                tik_instance.data_move(resmatmul[core_n_idx * 129024 + cc1 * 4096], resmatmul_local_ub, 0, 8, 512,
                                       0, 1504)
            tik_instance, resmatmul = _update_tik1(tik_instance, input_x1, input_x2, core_n_idx,
                                                   resmatmul, matrix_max_scalar)
        with tik_instance.else_scope():
            tik_instance.data_move(input_2_local_l1, input_x2[core_n_idx * 262144 + core_n_idx * 2048], 0, 8, 128,
                                   1920, 0)
            tik_instance.data_move(input_1_local_l1, input_x1[core_n_idx * 129024 + 2 * 4096], 0, 8, 256, 752, 0)
            with tik_instance.for_range(0, 8) as cc10:
                tik_instance.load2dv1(input_2_local_l1_local_l0b[cc10 * 2048], input_2_local_l1[cc10 * 256], 0, 8,
                                      8, 0, True)
            with tik_instance.for_range(0, 16) as cc101:
                tik_instance.load2dv1(input_1_local_l1_local_L0A[cc101 * 2048], input_1_local_l1[cc101 * 256],
                                      0, 8, 16, 0, False)

            tik_instance.mmad(resmatmul_local_ub_local_l0c, input_1_local_l1_local_L0A, input_2_local_l1_local_l0b,
                              256, 128, 128, 0)
            tik_instance.data_move(resmatmul_local_ub, resmatmul_local_ub_local_l0c, 0, 1, 128, 0, 0)
            tik_instance.vmuls(64, resmatmul_local_ub, resmatmul_local_ub, matrix_max_scalar, 255, 1, 1, 8, 8)
            tik_instance.vmuls(64, resmatmul_local_ub[255 * 64], resmatmul_local_ub[255 * 64], matrix_max_scalar,
                               255, 1, 1, 8, 8)
            tik_instance.vmuls(64, resmatmul_local_ub[510 * 64], resmatmul_local_ub[510 * 64], matrix_max_scalar,
                               2, 1, 1, 8, 8)

            tik_instance.data_move(resmatmul[core_n_idx * 129024 + 2 * 4096], resmatmul_local_ub, 0, 8, 512, 0,
                                   1504)

            tik_instance.data_move(input_2_local_l11, input_x2[core_n_idx * 262144 + core_n_idx * 2048], 0, 8, 128,
                                   1920, 0)
            tik_instance.data_move(input_1_local_l11, input_x1[core_n_idx * 129024 + 12288], 0, 8, 240, 768, 0)

            with tik_instance.for_range(0, 8) as cc102:
                tik_instance.load2dv1(input_2_local_l1_local_l0b1[cc102 * 2048], input_2_local_l11[cc102 * 256], 0,
                                      8, 8, 0, True)
            with tik_instance.for_range(0, 16) as cc103:
                tik_instance.load2dv1(input_1_local_l1_local_L0A[cc103 * 2048], input_1_local_l11[cc103 * 256], 0,
                                      8, 15, 0, False)

            tik_instance.mmad(resmatmul_local_ub_local_l0c1, input_1_local_l1_local_L0A,
                              input_2_local_l1_local_l0b1, 240, 128, 128, 0)
            tik_instance.data_move(resmatmul_local_ub1, resmatmul_local_ub_local_l0c1, 0, 1, 120, 0, 0)

            tik_instance.vmuls(64, resmatmul_local_ub1, resmatmul_local_ub1, matrix_max_scalar, 255, 1, 1, 8, 8)
            tik_instance.vmuls(64, resmatmul_local_ub1[255 * 64], resmatmul_local_ub1[255 * 64], matrix_max_scalar,
                               225, 1, 1, 8, 8)

            tik_instance.data_move(resmatmul[core_n_idx * 129024 + 12288], resmatmul_local_ub1, 0, 8, 480, 0, 1536)
            tik_instance, resmatmul = _update_tik2(tik_instance, input_x1, input_x2, core_n_idx,
                                                   resmatmul, matrix_max_scalar)

    tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x1, input_x2, input_x3], outputs=[resmatmul])
    return tik_instance
    return None


def _update_tik1(tik_instance, input_x1, input_x2, core_n_idx, resmatmul, matrix_max_scalar):
    """_update_tik1"""
    resmatmul_local_ub = tik_instance.Tensor("float32", (256 * 128,), scope=tik.scope_ubuf,
                                             name="resmatmul_local_ub")
    resmatmul_local_ub_local_l0c = tik_instance.Tensor("float32", (256 * 128,), scope=tik.scope_cc,
                                                       name="resmatmul_local_ub_local_l0c")
    input_1_local_l1_local_l0a = tik_instance.Tensor("float16", (256 * 128,), scope=tik.scope_ca,
                                                     name="input_1_local_l1_local_l0a")
    input_2_local_l1 = tik_instance.Tensor("float16", (8 * 128 * 16,), scope=tik.scope_cbuf,
                                           name="input_2_local_l1")

    input_1_local_l1 = tik_instance.Tensor("float16", (8 * 256 * 16,), scope=tik.scope_cbuf,
                                           name="input_1_local_l1")
    input_2_local_l1_local_l0b = tik_instance.Tensor("float16", (128 * 128,), scope=tik.scope_cb,
                                                     name="input_2_local_l1_local_l0b")
    with tik_instance.for_range(0, 2) as cc1:
        tik_instance.data_move(input_2_local_l1, input_x2[core_n_idx * 262144 + core_n_idx * 2048], 0, 8,
                               128, 1920, 0)
        tik_instance.data_move(input_1_local_l1, input_x1[core_n_idx * 129024 + cc1 * 4096], 0, 8, 256,
                               752, 0)
        with tik_instance.for_range(0, 8) as cc10:
            tik_instance.load2dv1(input_2_local_l1_local_l0b[cc10 * 2048], input_2_local_l1[cc10 * 256], 0,
                                  8, 8, 0, True)
        with tik_instance.for_range(0, 16) as cc101:
            tik_instance.load2dv1(input_1_local_l1_local_l0a[cc101 * 2048], input_1_local_l1[cc101 * 256],
                                  0, 8, 16, 0, False)

        tik_instance.mmad(resmatmul_local_ub_local_l0c, input_1_local_l1_local_l0a,
                          input_2_local_l1_local_l0b, 256, 128, 128, 0)
        tik_instance.data_move(resmatmul_local_ub, resmatmul_local_ub_local_l0c, 0, 1, 128, 0, 0)
        tik_instance.vmuls(64, resmatmul_local_ub, resmatmul_local_ub, matrix_max_scalar, 255, 1, 1, 8, 8)
        tik_instance.vmuls(64, resmatmul_local_ub[255 * 64], resmatmul_local_ub[255 * 64],
                           matrix_max_scalar, 255, 1, 1, 8, 8)
        tik_instance.vmuls(64, resmatmul_local_ub[510 * 64], resmatmul_local_ub[510 * 64],
                           matrix_max_scalar, 2, 1, 1, 8, 8)

        tik_instance.data_move(resmatmul[core_n_idx * 129024 + cc1 * 4096], resmatmul_local_ub, 0, 8, 512,
                               0, 1504)
    return tik_instance, resmatmul


def _update_tik2(tik_instance, input_x1, input_x2, core_n_idx, resmatmul, matrix_max_scalar):
    """_update_tik2"""
    resmatmul_local_ub = tik_instance.Tensor("float32", (256 * 128,), scope=tik.scope_ubuf,
                                             name="resmatmul_local_ub")
    resmatmul_local_ub1 = tik_instance.Tensor("float32", (240 * 128,), scope=tik.scope_ubuf,
                                              name="resmatmul_local_ub1")

    resmatmul_local_ub_local_l0c = tik_instance.Tensor("float32", (256 * 128,), scope=tik.scope_cc,
                                                       name="resmatmul_local_ub_local_l0c")
    resmatmul_local_ub_local_l0c1 = tik_instance.Tensor("float32", (240 * 128,), scope=tik.scope_cc,
                                                        name="resmatmul_local_ub_local_l0c1")

    input_1_local_l1_local_l0a = tik_instance.Tensor("float16", (256 * 128,), scope=tik.scope_ca,
                                                     name="input_1_local_l1_local_l0a")
    input_2_local_l1 = tik_instance.Tensor("float16", (8 * 128 * 16,), scope=tik.scope_cbuf,
                                           name="input_2_local_l1")
    input_2_local_l11 = tik_instance.Tensor("float16", (8 * 128 * 16,), scope=tik.scope_cbuf,
                                            name="input_2_local_l11")

    input_1_local_l1 = tik_instance.Tensor("float16", (8 * 256 * 16,), scope=tik.scope_cbuf,
                                           name="input_1_local_l1")
    input_1_local_l11 = tik_instance.Tensor("float16", (8 * 240 * 16,), scope=tik.scope_cbuf,
                                            name="input_1_local_l11")

    input_2_local_l1_local_l0b = tik_instance.Tensor("float16", (128 * 128,), scope=tik.scope_cb,
                                                     name="input_2_local_l1_local_l0b")
    input_2_local_l1_local_l0b1 = tik_instance.Tensor("float16", (128 * 128,), scope=tik.scope_cb,
                                                      name="input_2_local_l1_local_l0b1")

    tik_instance.data_move(input_2_local_l1, input_x2[core_n_idx * 262144 + core_n_idx * 2048], 0, 8, 128,
                           1920, 0)
    tik_instance.data_move(input_1_local_l1, input_x1[core_n_idx * 129024 + 2 * 4096], 0, 8, 256, 752, 0)
    with tik_instance.for_range(0, 8) as cc10:
        tik_instance.load2dv1(input_2_local_l1_local_l0b[cc10 * 2048], input_2_local_l1[cc10 * 256], 0, 8,
                              8, 0, True)
    with tik_instance.for_range(0, 16) as cc101:
        tik_instance.load2dv1(input_1_local_l1_local_l0a[cc101 * 2048], input_1_local_l1[cc101 * 256],
                              0, 8, 16, 0, False)

    tik_instance.mmad(resmatmul_local_ub_local_l0c, input_1_local_l1_local_l0a, input_2_local_l1_local_l0b,
                      256, 128, 128, 0)
    tik_instance.data_move(resmatmul_local_ub, resmatmul_local_ub_local_l0c, 0, 1, 128, 0, 0)
    tik_instance.vmuls(64, resmatmul_local_ub, resmatmul_local_ub, matrix_max_scalar, 255, 1, 1, 8, 8)
    tik_instance.vmuls(64, resmatmul_local_ub[255 * 64], resmatmul_local_ub[255 * 64], matrix_max_scalar,
                       255, 1, 1, 8, 8)
    tik_instance.vmuls(64, resmatmul_local_ub[510 * 64], resmatmul_local_ub[510 * 64], matrix_max_scalar,
                       2, 1, 1, 8, 8)

    tik_instance.data_move(resmatmul[core_n_idx * 129024 + 2 * 4096], resmatmul_local_ub, 0, 8, 512, 0,
                           1504)

    tik_instance.data_move(input_2_local_l11, input_x2[core_n_idx * 262144 + core_n_idx * 2048], 0, 8, 128,
                           1920, 0)
    tik_instance.data_move(input_1_local_l11, input_x1[core_n_idx * 129024 + 12288], 0, 8, 240, 768, 0)

    with tik_instance.for_range(0, 8) as cc102:
        tik_instance.load2dv1(input_2_local_l1_local_l0b1[cc102 * 2048], input_2_local_l11[cc102 * 256], 0,
                              8, 8, 0, True)
    with tik_instance.for_range(0, 16) as cc103:
        tik_instance.load2dv1(input_1_local_l1_local_l0a[cc103 * 2048], input_1_local_l11[cc103 * 256], 0,
                              8, 15, 0, False)

    tik_instance.mmad(resmatmul_local_ub_local_l0c1, input_1_local_l1_local_l0a,
                      input_2_local_l1_local_l0b1, 240, 128, 128, 0)
    tik_instance.data_move(resmatmul_local_ub1, resmatmul_local_ub_local_l0c1, 0, 1, 120, 0, 0)

    tik_instance.vmuls(64, resmatmul_local_ub1, resmatmul_local_ub1, matrix_max_scalar, 255, 1, 1, 8, 8)
    tik_instance.vmuls(64, resmatmul_local_ub1[255 * 64], resmatmul_local_ub1[255 * 64], matrix_max_scalar,
                       225, 1, 1, 8, 8)

    tik_instance.data_move(resmatmul[core_n_idx * 129024 + 12288], resmatmul_local_ub1, 0, 8, 480, 0, 1536)
    return tik_instance, resmatmul
@@ -44,6 +44,31 @@ matmul_cube_fracz_left_cast_op_info = TBERegOp("CusMatMulCubeFraczLeftCast") \
    .get_op_info()


def _clip_num(num):
    """clip number"""
    if num == 0:
        num = 1
    return num


def _get_block(trans_a, trans_b, m_shape, n_shape, km_shape, kn_shape):
    """_get_block"""
    block_in0 = cce.BLOCK_IN
    block_out0 = cce.BLOCK_OUT
    if trans_a and km_shape == 1:
        block_in0 = cce.BLOCK_VECTOR

    if not trans_a and m_shape == 1:
        block_in0 = cce.BLOCK_VECTOR

    if trans_b and kn_shape == 1:
        block_out0 = cce.BLOCK_VECTOR

    if not trans_b and n_shape == 1:
        block_out0 = cce.BLOCK_VECTOR
    return block_in0, block_out0


@op_info_register(matmul_cube_fracz_left_cast_op_info)
def cus_matmul_cube_fraczleftcast(input_x1, input_x2, bias=None, output_y=None, trans_a=False, trans_b=False,
                                  kernel_name="cus_matmul_cube_fraczleftcast"):
@ -52,132 +77,113 @@ def cus_matmul_cube_fraczleftcast(input_x1, input_x2, bias=None, output_y=None,
    data with fractal format.

    Parameters:
    shape_a: list or tuple
    shape_aa: list or tuple
        Shape of the first tensor a with rank > 1
    shape_b: list or tuple
    shape_bb: list or tuple
        Shape of the second tensor b with the same type as a,
        and shape_a, shape_b must be 2 dims
        and shape_aa, shape_bb must be 2 dims
    src_dtype: str
        The data type of input, support "float32", "float16"
    dst_dtype: str
        The data type of output, support "float32", "float16"
    trans_a: bool
        If True, shape_a == transposed before multiplication
        If True, shape_aa is transposed before multiplication
    trans_b: bool
        If True, shape_b == transposed before multiplication
        If True, shape_bb is transposed before multiplication
    is_fractal: bool
        If True, the input data format of a and b must be fractal format
    shape_bias: list or tuple
    shape_bbias: list or tuple
        Shape of bias, only support the input data format with ND

    Returns
    -------
    None
    """
    shape_a = input_x1.get("ori_shape")
    shape_b = input_x2.get("ori_shape")
    print("============")
    print(input_x1.get("format"), input_x2.get("format"))
    print(shape_a, shape_b)
    print("============")
    shape_aa = input_x1.get("ori_shape")
    shape_bb = input_x2.get("ori_shape")
    if input_x2.get("format") == "FRACTAL_Z":
        n, c, h, w = shape_b
        n, c, h, w = shape_bb
        c0 = 16
        c1 = c // c0
        if c1 == 0:
            c1 = 1
        shape_b = [n, c1 * h * w * c0]
        shape_a = [n, n]
        c1 = _clip_num(c1)
        shape_bb = [n, c1 * h * w * c0]
        shape_aa = [n, n]

    if input_x1.get("format") == "FRACTAL_Z":
        n, c, h, w = shape_a
        n, c, h, w = shape_aa
        c0 = 16
        c1 = c // c0
        if c1 == 0:
            c1 = 1
        shape_a = [n, c1 * h * w * c0]
        shape_b = [c1 * h * w * c0, c1 * h * w * c0]
        c1 = _clip_num(c1)
        shape_aa = [n, c1 * h * w * c0]
        shape_bb = [c1 * h * w * c0, c1 * h * w * c0]

    if input_x2.get("format") == "FRACTAL_NZ":
        shape_a = [shape_b[0], shape_b[0]]
        shape_aa = [shape_bb[0], shape_bb[0]]

    if input_x1.get("format") == "FRACTAL_NZ":
        shape_b = [shape_a[1], shape_a[1]]
        shape_bb = [shape_aa[1], shape_aa[1]]

    shape_a = list(shape_a)
    shape_b = list(shape_b)
    shape_aa = list(shape_aa)
    shape_bb = list(shape_bb)

    shape_a = _get_input_shape(shape_a)
    shape_b = _get_input_shape(shape_b)
    shape_aa = _get_input_shape(shape_aa)
    shape_bb = _get_input_shape(shape_bb)

    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape_a)
    util.check_shape_rule(shape_b)
    util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT)
    util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT)
    util.check_shape_rule(shape_aa)
    util.check_shape_rule(shape_bb)
    util.check_shape_size(shape_aa, SHAPE_SIZE_LIMIT)
    util.check_shape_size(shape_bb, SHAPE_SIZE_LIMIT)

    shape_a = [shape_a[1], shape_a[0]]
    shape_aa = [shape_aa[1], shape_aa[0]]
    trans_a = bool(1 - trans_a)

    shape_b = [shape_b[1], shape_b[0]]
    shape_bb = [shape_bb[1], shape_bb[0]]
    trans_b = bool(1 - trans_b)
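    # The swaps above exchange each operand's (rows, cols) and flip the matching transpose
    # flag; reading this as re-expressing the same GEMM with both operands stored transposed
    # is an interpretation added for clarity, not a statement from the original source.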

    shape_bias = ()
    shape_bbias = ()
    if bias is not None and bool(bias):
        shape_bias = bias.get("shape")
        shape_bias = list(shape_bias)
        shape_bias = _get_bias(shape_bias)
        shape_bbias = bias.get("shape")
        shape_bbias = list(shape_bbias)
        shape_bbias = _get_bias(shape_bbias)

    src_dtype = input_x1.get("dtype").lower()
    _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b)
    _shape_check(shape_aa, shape_bb, shape_bbias, src_dtype, trans_a, trans_b)

    m_shape = shape_a[len(shape_a) - 2]
    km_shape = shape_a[len(shape_a) - 1]
    kn_shape = shape_b[len(shape_a) - 2]
    n_shape = shape_b[len(shape_a) - 1]
    m_shape = shape_aa[len(shape_aa) - 2]
    km_shape = shape_aa[len(shape_aa) - 1]
    kn_shape = shape_bb[len(shape_aa) - 2]
    n_shape = shape_bb[len(shape_aa) - 1]

    if src_dtype == "float16":
        block_reduce = cce.BLOCK_REDUCE

    block_in = cce.BLOCK_IN
    block_out = cce.BLOCK_OUT

    if trans_a and km_shape == 1:
        block_in = cce.BLOCK_VECTOR

    if not trans_a and m_shape == 1:
        block_in = cce.BLOCK_VECTOR

    if trans_b and kn_shape == 1:
        block_out = cce.BLOCK_VECTOR

    if not trans_b and n_shape == 1:
        block_out = cce.BLOCK_VECTOR
    block_in0, block_out0 = _get_block(trans_a, trans_b, m_shape, n_shape, km_shape, kn_shape)

    if trans_a:
        shape_a_temp = (m_shape // block_reduce, km_shape // block_in, block_reduce, block_in)
        shape_aa_tmp = (m_shape // block_reduce, km_shape // block_in0, block_reduce, block_in0)
    else:
        shape_a_temp = (m_shape // block_in, km_shape // block_reduce, block_in, block_reduce)
        shape_aa_tmp = (m_shape // block_in0, km_shape // block_reduce, block_in0, block_reduce)

    if trans_b:
        shape_b_temp = (kn_shape // block_out, n_shape // block_reduce, block_reduce, block_out)
        shape_bb_tmp = (kn_shape // block_out0, n_shape // block_reduce, block_reduce, block_out0)
    else:
        shape_b_temp = (kn_shape // block_reduce, n_shape // block_out, block_out, block_reduce)
    shape_a_temp = (shape_a_temp[0], shape_a_temp[1], shape_a_temp[2], shape_a_temp[3])
    shape_b_temp = (shape_b_temp[0], shape_b_temp[1], shape_b_temp[2], shape_b_temp[3])
        shape_bb_tmp = (kn_shape // block_reduce, n_shape // block_out0, block_out0, block_reduce)
    shape_aa_tmp = (shape_aa_tmp[0], shape_aa_tmp[1], shape_aa_tmp[2], shape_aa_tmp[3])
    shape_bb_tmp = (shape_bb_tmp[0], shape_bb_tmp[1], shape_bb_tmp[2], shape_bb_tmp[3])

    if util.get_product_version() == util.VERSION_MINI:
        tik_instance = tik.Tik(tik.Dprofile("v100", "mini"))
    else:
        tik_instance = tik.Tik(tik.Dprofile("v100", "cloud"))
    input_x1 = tik_instance.Tensor(input_x1.get("dtype"), shape_a_temp, name="left_matrix", scope=tik.scope_gm)
    input_x2 = tik_instance.Tensor(input_x2.get("dtype"), shape_b_temp, name="right_matrix", scope=tik.scope_gm)
    res_matmul = tik_instance.Tensor(output_y.get("dtype"), output_y.get("shape"), name="output", scope=tik.scope_gm)
    mo_tile, ko_tile, no_tile, diag_opt = get_cus_tile_info(input_x1, input_x2, 128)
    cus_cube_matmul_cast(tik_instance, input_x1, trans_a, input_x2, trans_b, res_matmul,
    input_x1 = tik_instance.Tensor(input_x1.get("dtype"), shape_aa_tmp, name="left_matrix", scope=tik.scope_gm)
    input_x2 = tik_instance.Tensor(input_x2.get("dtype"), shape_bb_tmp, name="right_matrix", scope=tik.scope_gm)
    res_matmul0 = tik_instance.Tensor(output_y.get("dtype"), output_y.get("shape"), name="output", scope=tik.scope_gm)
    mo_tile, ko_tile, no_tile, dig_opt = get_cus_tile_info(input_x1, input_x2, 128)
    cus_cube_matmul_cast(tik_instance, input_x1, trans_a, input_x2, trans_b, res_matmul0,
                         mo_tile=mo_tile, ko_tile=ko_tile, no_tile=no_tile,
                         diag_opt=diag_opt, diag_size=128)
    tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x1, input_x2], outputs=[res_matmul])
                         diag_opt=dig_opt, diag_size=128)
    tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x1, input_x2], outputs=[res_matmul0])
    return tik_instance

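# Illustrative sketch, not part of the original commit: the FRACTAL_Z branches above flatten a
# 4-D weight ori_shape (n, c, h, w) into the 2-D GEMM shape [n, c1 * h * w * c0] with c0 == 16
# and c1 clipped to at least 1 via _clip_num. A standalone restatement of that step:
def _demo_fracz_to_2d(ori_shape, c0=16):
    """Hypothetical helper mirroring the FRACTAL_Z shape flattening above."""
    n, c, h, w = ori_shape
    c1 = _clip_num(c // c0)
    return [n, c1 * h * w * c0]
# e.g. a (64, 32, 7, 7) weight flattens to [64, 2 * 7 * 7 * 16] == [64, 1568].
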
@ -86,9 +86,9 @@ def cus_matmul_cube_fraczrightmul(input_x1, input_x2, input_x3, output_y=None, k
    input_x1 = tik_instance.Tensor("float16", input_x1_shape, name="left_matrix", scope=tik.scope_gm)
    input_x2 = tik_instance.Tensor("float16", input_x2_shape, name="right_matrix", scope=tik.scope_gm)
    input_x3 = tik_instance.Tensor("float32", input_x3_shape, name="matrix_max", scope=tik.scope_gm)
    resMatmul = tik_instance.Tensor("float32", output_shape, name="output", scope=tik.scope_gm)
    cus_cube_matmul_right_mul(tik_instance, input_x1, input_x2, input_x3, resMatmul)
    tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x1, input_x2, input_x3], outputs=[resMatmul])
    resmatmul = tik_instance.Tensor("float32", output_shape, name="output", scope=tik.scope_gm)
    cus_cube_matmul_right_mul(tik_instance, input_x1, input_x2, input_x3, resmatmul)
    tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x1, input_x2, input_x3], outputs=[resmatmul])
    return tik_instance
@ -177,7 +177,7 @@ def cus_cube_matmul_right_mul(tik_instance, input_x1, input_x2, input_x3,
    core_m = block_idx // core_n_num
    core_n = block_idx % core_n_num
    res_l0c = tik_instance.Tensor("float32", [no_tile, mo_tile, c0, c0],
                                  name="resMatmul_L0C", scope=tik.scope_cc)
                                  name="resmatmul_L0C", scope=tik.scope_cc)
    with tik_instance.for_range(0, loop_k_num, thread_num=thread_num_k) as thread_idx_k:
        if diag_opt:
            k_idx = (core_n * loop_n_num + cc_n) * no_tile + thread_idx_k * ko_tile_inner
@ -219,7 +219,7 @@ def cus_cube_matmul_right_mul(tik_instance, input_x1, input_x2, input_x3,
    tik_instance.mmad(res_l0c, input_x1_l0a, input_x2_l0b, mo_tile * c0,
                      ko_tile_inner * c0, no_tile * c0, 1)
    res_ub = tik_instance.Tensor("float32", [no_tile, mo_tile, c0, c0],
                                 name="resMatmul_ub", scope=tik.scope_ubuf)
                                 name="resmatmul_ub", scope=tik.scope_ubuf)
    tik_instance.data_move(res_ub, res_l0c, 0, 1, no_tile * mo_tile, 0, 0)

    input_3_local_ub = tik_instance.Tensor("float32", (8,), scope=tik.scope_ubuf, name="input_3_local_ub")
@ -110,13 +110,13 @@ def cus_matmul_cube(input_x1, input_x2, bias=None, output_y=None, trans_a=False,
    result = te.lang.cce.matmul(tensor_a, tensor_b, trans_a, trans_b, format_a=format_a,
                                format_b=format_b, dst_dtype=dst_dtype, tensor_bias=tensor_bias)

    with tvm.target.cce():
        schedule = generic.auto_schedule(result)

    tensor_list = [tensor_a, tensor_b, result]
    if shape_bias:
        tensor_list = [tensor_a, tensor_b, tensor_bias, result]

    with tvm.target.cce():
        schedule = generic.auto_schedule(result)

    config = {"print_ir": False,
              "name": kernel_name,
              "tensor_list": tensor_list}
@ -163,19 +163,21 @@ def _get_block(shape_a, shape_b, trans_a, trans_b):
    km_shape = shape_a[len(shape_a) - 1]
    kn_shape = shape_b[len(shape_a) - 2]
    n_shape = shape_b[len(shape_a) - 1]
    block_in = cce.BLOCK_IN
    block_out = cce.BLOCK_OUT
    if trans_a and km_shape == 1:
        block_in = cce.BLOCK_VECTOR

    if not trans_a and m_shape == 1:
        block_in = cce.BLOCK_VECTOR
    block_in = cce.BLOCK_IN

    if trans_b and kn_shape == 1:
        block_out = cce.BLOCK_VECTOR

    if not trans_b and n_shape == 1:
        block_out = cce.BLOCK_VECTOR

    if trans_a and km_shape == 1:
        block_in = cce.BLOCK_VECTOR

    if not trans_a and m_shape == 1:
        block_in = cce.BLOCK_VECTOR

    return block_in, block_out
@ -82,7 +82,6 @@ def minmax_update_perchannel(x, min_val, max_val, min_up, max_up,
                             kernel_name="minmax_update_perchannel"):
    """MinMaxUpdatePerChannel op"""
    x_shape = x.get("ori_shape")
    x_format = x.get("format")
    x_dtype = x.get("dtype")
    min_shape = min_val.get("ori_shape")
    min_dtype = min_val.get("dtype")
@ -39,12 +39,8 @@ def _get_tik():
    return tik_instance


@op_info_register(cus_transpose02314_op_info)
def cus_transpose02314(input_x, output, kernel_name="cus_transpose021354"):
    """CusTranspose02314"""
    input_x_shape = input_x.get("shape")
    output_shape = output.get("shape")
    input_x_shape = tuple(input_x_shape)
def _error_feedback(input_x_shape):
    """error feedback"""
    support_shape = [(32, 128, 7, 7, 16),
                     (32, 32, 7, 7, 16),
                     (32, 32, 14, 14, 16),
@ -60,6 +56,16 @@ def cus_transpose02314(input_x, output, kernel_name="cus_transpose021354"):
    if input_x_shape not in support_shape:
        raise RuntimeError("input_shape %s is not supported" % str(input_x_shape))


@op_info_register(cus_transpose02314_op_info)
def cus_transpose02314(input_x, output, kernel_name="cus_transpose021354"):
    """CusTranspose02314"""
    input_x_shape = input_x.get("shape")
    output_shape = output.get("shape")
    input_x_shape = tuple(input_x_shape)

    _error_feedback(input_x_shape)
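    # Illustrative note, an assumption for clarity rather than a statement from the source:
    # for any shape outside support_shape, e.g. _error_feedback((1, 2, 3, 4, 5)), the helper
    # raises RuntimeError("input_shape (1, 2, 3, 4, 5) is not supported") before any TIK
    # instance is created.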

    tik_instance = _get_tik()

    input_x = tik_instance.Tensor("float16", input_x_shape, name="input_x", scope=tik.scope_gm)

@ -302,6 +308,7 @@ def shape9(tik_instance, input_x, res, dtype):

def shape10(tik_instance, input_x, res, dtype):
    """input shape (32, 32, 14, 14, 16)"""

    def _inner_compute(split_index):
        input_x_ub = tik_instance.Tensor(dtype, [1, 32, 2, 14, 16], name="input_1_local_ub",
                                         scope=tik.scope_ubuf)
@ -323,6 +330,7 @@ def shape10(tik_instance, input_x, res, dtype):

def shape11(tik_instance, input_x, res, dtype):
    """input shape (32, 64, 14, 14, 16)"""

    def _inner_compute(split_index, block_idx):
        input_x_ub = tik_instance.Tensor(dtype, [1, 64, 2, 14, 16], name="input_1_local_ub",
                                         scope=tik.scope_ubuf)