!17785 clean code for thor

Merge pull request !17785 from melody/master
i-robot 2021-06-07 14:34:35 +08:00 committed by Gitee
commit 346a35ba7d
18 changed files with 426 additions and 389 deletions

View File

@ -14,22 +14,3 @@
# ============================================================================
"""custom ops"""
from .batchnorm_fold import _batchnorm_fold_tbe
from .batchnorm_fold2 import _batchnorm_fold2_tbe
from .batchnorm_fold2_grad import _batchnorm_fold2_grad_tbe
from .batchnorm_fold2_grad_reduce import _batchnorm_fold2_grad_reduce_tbe
from .batchnorm_fold_grad import _batchnorm_fold_grad_tbe
from .correction_mul import _correction_mul_tbe
from .correction_mul_grad import _correction_mul_grad_tbe
from .fake_learned_scale_quant_perlayer import _fake_learned_scale_quant_perlayer_tbe
from .fake_learned_scale_quant_perlayer_grad import _fake_learned_scale_quant_perlayer_grad_d_tbe
from .fake_learned_scale_quant_perlayer_grad_reduce import _fake_learned_scale_quant_perlayer_grad_d_reduce_tbe
from .fake_learned_scale_quant_perchannel import _fake_learned_scale_quant_perchannel_tbe
from .fake_learned_scale_quant_perchannel_grad import _fake_learned_scale_quant_perchannel_grad_d_tbe
from .fake_learned_scale_quant_perchannel_grad_reduce import _fake_learned_scale_quant_perchannel_grad_d_reduce_tbe
from .fake_quant_perchannel import _fake_quant_perchannel_tbe
from .fake_quant_perchannel_grad import _fake_quant_perchannel_grad_tbe
from .fake_quant_perlayer import _fake_quant_per_layer_tbe
from .fake_quant_perlayer_grad import _fake_quant_per_layer_grad_tbe
from .minmax_update_perchannel import _minmax_update_perchannel_tbe
from .minmax_update_perlayer import _minmax_update_perlayer_tbe

View File

@ -1,20 +1,18 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""
copyright 2020 Huawei Technologies Co., Ltd
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License == distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
_basic
"""
from __future__ import absolute_import

View File

@ -43,6 +43,25 @@ def _get_flattern_shape(shape):
return (flattern_shape,)
def _error_feedback(input_shape):
"""error feedback"""
support_shape = [((8, 128, 128), (8, 128, 128), "float32", False, True),
((36, 128, 128), (36, 128, 128), "float32", False, True),
((5, 128, 128), (5, 128, 128), "float32", False, True),
((18, 128, 128), (18, 128, 128), "float32", False, True),
((16, 128, 128), (16, 128, 128), "float32", False, True),
((9, 128, 128), (9, 128, 128), "float32", False, True),
((1, 64, 64), (1, 64, 64), "float32", False, True),
((1, 128, 128), (1, 128, 128), "float32", False, True),
((4, 128, 128), (4, 128, 128), "float32", False, True),
((2, 128, 128), (2, 128, 128), "float32", False, True),
((6, 128, 128), (6, 128, 128), "float32", False, True),
((24, 128, 128), (24, 128, 128), "float32", False, True),
((32, 128, 128), (32, 128, 128), 'float32', False, True)]
if input_shape not in support_shape:
raise RuntimeError("input_shape %s is not supported" % str(input_shape))
def _inner_matmul_new(tik_instance, dtype, input_info, res, res_index):
"""_inner_matmul_new"""
input1, input1_index, input2, input2_index = input_info
@ -100,9 +119,9 @@ def process_input_shape_640(input_shape, tik_instance, dtype, total_input, res):
"""process input shape of 640"""
input1, input2 = total_input
if input_shape == ((5, 128, 128), (5, 128, 128), "float32", False, True):
with tik_instance.for_range(0, 30, block_num=30) as block_idx,\
tik_instance.for_range(0, 11) as cc1_db,\
tik_instance.for_range(0, 2, thread_num=2) as thread_idx,\
with tik_instance.for_range(0, 30, block_num=30) as block_idx, \
tik_instance.for_range(0, 11) as cc1_db, \
tik_instance.for_range(0, 2, thread_num=2) as thread_idx, \
tik_instance.if_scope(((((block_idx % 6) * 22) + (cc1_db * 2) + thread_idx) < 128)):
input_1_local_ub = tik_instance.Tensor(dtype, [128], name="input_1_local_ub",
scope=tik.scope_ubuf)
@ -157,17 +176,13 @@ def process_input_shape_1152(input_shape, tik_instance, dtype, total_input, res)
input2_index = (block_idx // 3) * 16384
res_index = (block_idx // 3) * 16384 + (block_idx % 3) * 5504 + cc0 * 128
input_info = input1, input1_index, input2, input2_index
_inner_matmul_new(tik_instance, dtype,
input_info,
res, res_index)
_inner_matmul_new(tik_instance, dtype, input_info, res, res_index)
with tik_instance.if_scope((block_idx % 3) < 2):
input1_index = (block_idx // 3) * 16384 + (block_idx % 3) * 5504 + 42 * 128
input2_index = (block_idx // 3) * 16384
res_index = (block_idx // 3) * 16384 + (block_idx % 3) * 5504 + 42 * 128
input_info = input1, input1_index, input2, input2_index
_inner_matmul_new(tik_instance, dtype,
input_info,
res, res_index)
_inner_matmul_new(tik_instance, dtype, input_info, res, res_index)
@op_info_register(cus_batchmatmul_op_info)
@ -185,23 +200,9 @@ def cus_batch_matmul(input_x1, input_x2, output, transpose_a=False,
raise RuntimeError("dtype of input_x1 and input_x2 must be same, but got %s vs %s" % (
dtype, input_x2.get("dtype").lower()))
input_shape = (tuple(x1_shape), tuple(x2_shape), dtype, transpose_a, transpose_b)
support_shape = [((8, 128, 128), (8, 128, 128), "float32", False, True),
((36, 128, 128), (36, 128, 128), "float32", False, True),
((5, 128, 128), (5, 128, 128), "float32", False, True),
((18, 128, 128), (18, 128, 128), "float32", False, True),
((16, 128, 128), (16, 128, 128), "float32", False, True),
((9, 128, 128), (9, 128, 128), "float32", False, True),
((1, 64, 64), (1, 64, 64), "float32", False, True),
((1, 128, 128), (1, 128, 128), "float32", False, True),
((4, 128, 128), (4, 128, 128), "float32", False, True),
((2, 128, 128), (2, 128, 128), "float32", False, True),
((6, 128, 128), (6, 128, 128), "float32", False, True),
((24, 128, 128), (24, 128, 128), "float32", False, True),
((32, 128, 128), (32, 128, 128), 'float32', False, True)]
if input_shape not in support_shape:
raise RuntimeError("input_shape %s is not supported" % str(input_shape))
# if not transpose_a and transpose_b:
_error_feedback(input_shape)
batch, m, k = x1_shape
input1_shape = _get_flattern_shape(x1_shape)
@ -215,14 +216,13 @@ def cus_batch_matmul(input_x1, input_x2, output, transpose_a=False,
if input_shape == ((36, 128, 128), (36, 128, 128), "float32", False, True):
with tik_instance.for_range(0, 18, block_num=18) as block_idx, \
tik_instance.for_range(0, 2) as cc0,\
tik_instance.for_range(0, 2) as cc0, \
tik_instance.for_range(0, 128, thread_num=2) as cc1:
input1_index = block_idx * 32768 + cc0 * 16384 + cc1 * 128
input2_index = block_idx * 32768 + cc0 * 16384
res_index = block_idx * 32768 + cc0 * 16384 + cc1 * 128
input_info = input1, input1_index, input2, input2_index
_inner_matmul_new(tik_instance, dtype,
input_info, res, res_index)
_inner_matmul_new(tik_instance, dtype, input_info, res, res_index)
total_input = input1, input2
process_input_shape_640(input_shape, tik_instance, dtype, total_input, res)
@ -234,20 +234,18 @@ def cus_batch_matmul(input_x1, input_x2, output, transpose_a=False,
input2_index = block_idx * 16384
res_index = block_idx * 16384 + cc0 * 128
input_info = input1, input1_index, input2, input2_index
_inner_matmul_new(tik_instance, dtype,
input_info, res, res_index)
_inner_matmul_new(tik_instance, dtype, input_info, res, res_index)
process_input_shape_1152(input_shape, tik_instance, dtype, total_input, res)
if input_shape == ((1, 64, 64), (1, 64, 64), "float32", False, True):
with tik_instance.for_range(0, 32, block_num=32) as block_idx,\
with tik_instance.for_range(0, 32, block_num=32) as block_idx, \
tik_instance.for_range(0, 2, thread_num=2) as cc0:
input1_index = block_idx * 128 + cc0 * 64
input2_index = 0
res_index = block_idx * 128 + cc0 * 64
input_info = input1, input1_index, input2, input2_index
_inner_matmul_new_1_64_32_64(tik_instance, dtype,
input_info,
_inner_matmul_new_1_64_32_64(tik_instance, dtype, input_info,
res, res_index)
input_shape_list = [((1, 128, 128), (1, 128, 128), "float32", False, True),
@ -257,14 +255,12 @@ def cus_batch_matmul(input_x1, input_x2, output, transpose_a=False,
((8, 128, 128), (8, 128, 128), "float32", False, True),
((16, 128, 128), (16, 128, 128), "float32", False, True),
((24, 128, 128), (24, 128, 128), "float32", False, True),
((32, 128, 128), (32, 128, 128), 'float32', False, True)
]
((32, 128, 128), (32, 128, 128), 'float32', False, True)]
if input_shape in input_shape_list:
block_num = 32
block_num, thread_num = 32, 2
input1_unit_size = 128
input2_unint_size = 128 * 128
block_process_ele_num = (batch * m * k) // block_num
thread_num = 2
loop_time = (batch * m * k) // block_num // input1_unit_size
with tik_instance.for_range(0, block_num, block_num=block_num) as block_idx, \
tik_instance.for_range(0, loop_time, thread_num=thread_num) as cc0:
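
For orientation, a minimal NumPy sketch of what CusBatchMatMul computes for its supported shapes, assuming the usual transpose_b=True semantics (the second operand is transposed per batch); this is a reference only, not the TIK kernel:

import numpy as np

def batch_matmul_reference(x1, x2):
    """Reference: res[i] = x1[i] @ x2[i].T for every batch index i (transpose_b=True)."""
    assert x1.shape == x2.shape and x1.dtype == np.float32
    return np.einsum('bik,bjk->bij', x1, x2)

x1 = np.random.rand(8, 128, 128).astype(np.float32)
x2 = np.random.rand(8, 128, 128).astype(np.float32)
print(batch_matmul_reference(x1, x2).shape)  # (8, 128, 128)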

View File

@ -28,11 +28,11 @@ batch_norm_op_info = TBERegOp("BatchNormFoldD") \
.compute_cost(10) \
.kernel_name("batchnorm_fold") \
.partial_flag(True) \
.attr("momentum", "optional", "float", "all", "0.9") \
.attr("epsilon", "optional", "float", "all", "0.00001") \
.attr("is_training", "optional", "bool", "all", "true") \
.attr("freeze_bn", "optional", "int", "all", "0") \
.attr("format", "optional", "str", "all", "NCHW") \
.attr("momentum", "optional", "float", "all") \
.attr("epsilon", "optional", "float", "all") \
.attr("is_training", "optional", "bool", "all") \
.attr("freeze_bn", "optional", "int", "all") \
.attr("format", "optional", "str", "all") \
.input(0, "x", False, "required", "all") \
.input(1, "x_sum", False, "required", "all") \
.input(2, "x_square_sum", False, "required", "all") \
@ -57,6 +57,43 @@ def _batchnorm_fold_tbe():
return
def _batchnorm_fold_compute(x_input, x_sum, x_square_sum, mean, variance, momentum, epsilon):
"""_batchnorm_fold_compute"""
shape_x = te.lang.cce.util.shape_to_list(x_input.shape)
num = shape_x[0] * shape_x[2] * shape_x[3]
num_rec = 1.0 / num
# compute the mean of x
batch_mean = te.lang.cce.vmuls(x_sum, num_rec)
# compute the variance of x
variance_div = te.lang.cce.vmuls(x_square_sum, num_rec)
mean_square = te.lang.cce.vmul(batch_mean, batch_mean)
batch_var_biased = te.lang.cce.vsub(variance_div, mean_square)
batch_std = te.lang.cce.vsqrt(te.lang.cce.vadds(batch_var_biased, epsilon))
if num == 1:
batch_var_scaler = 0.0
else:
batch_var_scaler = float(num) / (num - 1)
batch_var_unbiased = te.lang.cce.vmuls(batch_var_biased, batch_var_scaler)
factor = 1.0 - momentum
factor_reverse = momentum
mean_mul = te.lang.cce.vmuls(batch_mean, factor)
mean_mul_rev = te.lang.cce.vmuls(mean, factor_reverse)
mean_updated = te.lang.cce.vadd(mean_mul, mean_mul_rev)
var_mul = te.lang.cce.vmuls(batch_var_unbiased, factor)
var_mul_rev = te.lang.cce.vmuls(variance, factor_reverse)
variance_updated = te.lang.cce.vadd(var_mul, var_mul_rev)
y = te.lang.cce.vadds(x_input, 0.0)
running_mean = te.lang.cce.vadds(mean, 0.0)
running_std = te.lang.cce.vsqrt(te.lang.cce.vadds(variance, epsilon))
res = [y, batch_mean, batch_std, running_mean, running_std, mean_updated, variance_updated]
return res
@util.check_input_type(dict, dict, dict, dict, dict,
dict, dict, dict, dict, dict, dict, dict,
float, float, bool, int, str, str)
@ -108,39 +145,7 @@ def batchnorm_fold(x, x_sum, x_square_sum, mean, variance,
mean = tvm.placeholder(shape_mean, name="mean", dtype=dtype_mean.lower())
variance = tvm.placeholder(shape_mean, name="variance", dtype=dtype_variance.lower())
shape_x = te.lang.cce.util.shape_to_list(x_input.shape)
num = shape_x[0] * shape_x[2] * shape_x[3]
num_rec = 1.0 / num
# compute the mean of x
batch_mean = te.lang.cce.vmuls(x_sum, num_rec)
# compute the variance of x
variance_div = te.lang.cce.vmuls(x_square_sum, num_rec)
mean_square = te.lang.cce.vmul(batch_mean, batch_mean)
batch_var_biased = te.lang.cce.vsub(variance_div, mean_square)
batch_std = te.lang.cce.vsqrt(te.lang.cce.vadds(batch_var_biased, epsilon))
if num == 1:
batch_var_scaler = 0.0
else:
batch_var_scaler = float(num) / (num - 1)
batch_var_unbiased = te.lang.cce.vmuls(batch_var_biased, batch_var_scaler)
factor = 1.0 - momentum
factor_reverse = momentum
mean_mul = te.lang.cce.vmuls(batch_mean, factor)
mean_mul_rev = te.lang.cce.vmuls(mean, factor_reverse)
mean_updated = te.lang.cce.vadd(mean_mul, mean_mul_rev)
var_mul = te.lang.cce.vmuls(batch_var_unbiased, factor)
var_mul_rev = te.lang.cce.vmuls(variance, factor_reverse)
variance_updated = te.lang.cce.vadd(var_mul, var_mul_rev)
y = te.lang.cce.vadds(x_input, 0.0)
running_mean = te.lang.cce.vadds(mean, 0.0)
running_std = te.lang.cce.vsqrt(te.lang.cce.vadds(variance, epsilon))
res = [y, batch_mean, batch_std, running_mean, running_std, mean_updated, variance_updated]
res = _batchnorm_fold_compute(x_input, x_sum, x_square_sum, mean, variance, momentum, epsilon)
with tvm.target.cce():
sch = generic.auto_schedule(res)
config = {"name": kernel_name,
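
As a plain NumPy sketch, the statistics that _batchnorm_fold_compute derives from the per-channel sums look roughly like this (num is N*H*W; the momentum and epsilon defaults here are illustrative):

import numpy as np

def batchnorm_fold_stats(x_sum, x_square_sum, mean, variance, num, momentum=0.9, epsilon=1e-5):
    batch_mean = x_sum / num                                  # E[x]
    batch_var_biased = x_square_sum / num - batch_mean ** 2   # E[x^2] - E[x]^2
    batch_std = np.sqrt(batch_var_biased + epsilon)
    scaler = 0.0 if num == 1 else float(num) / (num - 1)      # Bessel correction
    batch_var_unbiased = batch_var_biased * scaler
    mean_updated = (1.0 - momentum) * batch_mean + momentum * mean
    variance_updated = (1.0 - momentum) * batch_var_unbiased + momentum * variance
    running_mean = mean
    running_std = np.sqrt(variance + epsilon)
    return batch_mean, batch_std, running_mean, running_std, mean_updated, variance_updated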

View File

@ -21,6 +21,7 @@ from te.platform.cce_build import build_config
from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
from impl.bn_training_reduce import bn_training_reduce_schedule_nd
SHAPE_SIZE_LIMIT = 2147483648
@ -81,9 +82,8 @@ def batchnorm_fold2_grad_reduce(dout, x, dout_reduce, dout_x_reduce, kernel_name
util.check_kernel_name(kernel_name)
util.check_shape_rule(shape)
util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
check_list = ["float16", "float32"]
inp_dtype = x.get("dtype").lower()
if not inp_dtype in check_list:
if not inp_dtype in ["float16", "float32"]:
raise RuntimeError("Dtype of input only support float16, float32")
dout_t = tvm.placeholder(shape, name="dout", dtype=inp_dtype)
x_t = tvm.placeholder(shape, name="x", dtype=inp_dtype)
@ -100,7 +100,7 @@ def batchnorm_fold2_grad_reduce(dout, x, dout_reduce, dout_x_reduce, kernel_name
te.lang.cce.cce_build_code(sch, config)
return
from impl.bn_training_reduce import bn_training_reduce_schedule_nd
sch, tensor_list = bn_training_reduce_schedule_nd(res_list)
with build_config:
tvm.build(sch, tensor_list, "cce", name=kernel_name)

View File

@ -95,12 +95,12 @@ def batchnorm_fold_grad(d_batch_mean, d_batch_std, x, batch_mean, batch_std, dx,
dtype_mean = d_batch_mean.get("dtype").lower()
if format_data == "NC1HWC0":
if len(shape_x) != 5:
raise RuntimeError("batchnorm_fold only support shape 5D"
raise RuntimeError("batchnorm_fold grad only support shape 5D"
"when input format is NC1HWC0")
shape_mean = (1, shape_x[1], 1, 1, shape_x[4])
elif format_data == "NCHW":
if len(shape_x) < 2 or len(shape_x) > 4:
raise RuntimeError("batchnorm_fold only support shape 2D to 4D")
raise RuntimeError("batchnorm_fold grad only support shape 2D to 4D")
if shape_x[1] != shape_mean[0]:
raise RuntimeError("data_format is NCHW, shape_bias must"
"be equal to the second axis of shape_x")

View File

@ -66,9 +66,8 @@ def correction_mul(x, batch_std, running_std, y, channel, kernel_name="correctio
util.check_kernel_name(kernel_name)
util.check_shape_rule(shape)
util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
check_list = ["float16", "float32"]
inp_dtype = x.get("dtype").lower()
if not inp_dtype in check_list:
if not inp_dtype in ["float16", "float32"]:
raise RuntimeError("Dtype of input only support float16, float32")
x_t = tvm.placeholder(shape, name="x", dtype=inp_dtype)

View File

@ -86,11 +86,9 @@ def fake_quant_perchannel_compute(x, min_val, max_val, y, quant_min, quant_max,
return res
@util.check_input_type(dict, dict, dict, dict, bool, bool, int, int, str)
def fake_quant_perchannel(x, min_val, max_val, y,
symmetric, narrow_range, num_bits, channel_axis,
kernel_name="fake_quant_perchannel"):
"""FakeQuantPerChannel"""
def fake_quant_perchannel_param(x, min_val, max_val, channel_axis,
kernel_name="fake_quant_perchannel"):
"""Get and check fake_quant_perchannel parameters"""
x_shape = x.get("shape")
x_shape_ = x.get("ori_shape")
x_format = x.get("format")
@ -120,15 +118,25 @@ def fake_quant_perchannel(x, min_val, max_val, y,
util.check_dtype_rule(min_dtype, check_list)
util.check_dtype_rule(max_dtype, check_list)
shape_c = [1] * len(x_shape)
shape_c[channel_axis_] = min_val.get("ori_shape")[0]
if x_format == "NC1HWC0" and channel_axis_ == 1:
shape_c = min_val.get("shape")
return x_shape, shape_c, x_dtype
@util.check_input_type(dict, dict, dict, dict, bool, bool, int, int, str)
def fake_quant_perchannel(x, min_val, max_val, y,
symmetric, narrow_range, num_bits, channel_axis,
kernel_name="fake_quant_perchannel"):
"""FakeQuantPerChannel"""
quant_min = 0
quant_max = 2 ** num_bits - 1
if narrow_range:
quant_min = quant_min + 1
shape_c = [1] * len(x_shape)
shape_c[channel_axis_] = min_val.get("ori_shape")[0]
if x_format == "NC1HWC0" and channel_axis_ == 1:
shape_c = min_val.get("shape")
x_shape, shape_c, x_dtype = fake_quant_perchannel_param(x, min_val, max_val,
channel_axis, kernel_name)
input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
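
The helper factored out above mainly derives the per-channel broadcast shape for the min/max tensors; a small sketch of that rule (the example shapes are made up):

def perchannel_broadcast_shape(x_shape, channel_axis, channel_size):
    """min_val/max_val are broadcast along channel_axis; every other axis is 1."""
    shape_c = [1] * len(x_shape)
    shape_c[channel_axis] = channel_size
    return shape_c

print(perchannel_broadcast_shape((32, 64, 14, 14), 1, 64))  # [1, 64, 1, 1]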

View File

@ -110,11 +110,9 @@ def fake_quant_perchannel_grad_compute(dout, x, min_val, max_val, quant_min, qua
return res
@util.check_input_type(dict, dict, dict, dict, dict, bool, bool, int, int, str)
def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
symmetric, narrow_range, num_bits, channel_axis,
kernel_name="fake_quant_perchannel_grad"):
"""FakeQuantPerChannelGrad"""
def fake_quant_perchannel_grad_param(x, min_val, max_val, channel_axis,
kernel_name="fake_quant_perchannel_grad"):
"""Get and check FakeQuantPerChannelGrad parameters"""
x_shape = x.get("shape")
x_shape_ = x.get("ori_shape")
x_format = x.get("format")
@ -144,6 +142,18 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
util.check_dtype_rule(min_dtype, check_list)
util.check_dtype_rule(max_dtype, check_list)
shape_c = [1] * len(x_shape)
shape_c[channel_axis_] = min_val.get("ori_shape")[0]
if x_format == "NC1HWC0" and channel_axis_ == 1:
shape_c = min_val.get("shape")
return x_shape, shape_c, x_dtype
@util.check_input_type(dict, dict, dict, dict, dict, bool, bool, int, int, str)
def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
symmetric, narrow_range, num_bits, channel_axis,
kernel_name="fake_quant_perchannel_grad"):
"""FakeQuantPerChannelGrad"""
if symmetric:
quant_min = 0 - 2 ** (num_bits - 1)
quant_max = 2 ** (num_bits - 1) - 1
@ -153,10 +163,8 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
if narrow_range:
quant_min = quant_min + 1
shape_c = [1] * len(x_shape)
shape_c[channel_axis_] = min_val.get("ori_shape")[0]
if x_format == "NC1HWC0" and channel_axis_ == 1:
shape_c = min_val.get("shape")
x_shape, shape_c, x_dtype = fake_quant_perchannel_grad_param(x, min_val, max_val,
channel_axis, kernel_name)
dout_data = tvm.placeholder(x_shape, name="dout", dtype=x_dtype)
input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
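
The quantization range used by the grad kernel follows the usual fake-quant convention; a sketch of both branches (the asymmetric branch is assumed to mirror the forward kernel, since the diff truncates it):

def quant_range(num_bits, symmetric, narrow_range):
    if symmetric:
        quant_min = -2 ** (num_bits - 1)
        quant_max = 2 ** (num_bits - 1) - 1
    else:
        quant_min, quant_max = 0, 2 ** num_bits - 1
    if narrow_range:
        quant_min += 1
    return quant_min, quant_max

print(quant_range(8, True, False))   # (-128, 127)
print(quant_range(8, False, True))   # (1, 255)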

View File

@ -54,6 +54,25 @@ def _update_tik(tik_instance, input_x_ub, broadcast_0_local_ub, block_index, res
return tik_instance, res
def _error_feedback(input_info):
"""error feedback"""
support_shape = [((1, 128, 128), "float32"),
((2, 128, 128), "float32"),
((4, 128, 128), "float32"),
((8, 128, 128), "float32"),
((16, 128, 128), "float32"),
((5, 128, 128), "float32"),
((9, 128, 128), "float32"),
((18, 128, 128), "float32"),
((36, 128, 128), "float32"),
((32, 128, 128), "float32"),
((1, 64, 64), "float32"),
((32, 64), "float32")
]
if input_info not in support_shape:
raise RuntimeError("input_shape %s is not supported" % str(input_info))
def shape0(tik_instance, input_x_shape, input_x, res):
"""shape0"""
total_elements0 = 1
@ -482,25 +501,13 @@ def cus_fused_abs_max1(input_x, output, origin_shape=None, kernel_name="cus_fuse
tik_instance = _get_tik_instance()
support_shape = [((1, 128, 128), "float32"),
((2, 128, 128), "float32"),
((4, 128, 128), "float32"),
((8, 128, 128), "float32"),
((16, 128, 128), "float32"),
((5, 128, 128), "float32"),
((9, 128, 128), "float32"),
((18, 128, 128), "float32"),
((36, 128, 128), "float32"),
((32, 128, 128), "float32"),
((1, 64, 64), "float32"),
((32, 64), "float32")
]
ori_shape = tuple(origin_shape)
input_info = (tuple(input_x_shape), dtype)
input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm)
res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm)
if input_info not in support_shape:
raise RuntimeError("input_shape %s is not supported" % str(input_info))
_error_feedback(input_info)
if input_info == ((1, 128, 128), "float32"):
tik_instance, res = shape0(tik_instance, input_x_shape, input_x, res)
elif input_info == ((2, 128, 128), "float32"):
@ -534,8 +541,6 @@ def cus_fused_abs_max1(input_x, output, origin_shape=None, kernel_name="cus_fuse
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8)
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8)
tik_instance.data_move(res[0], input_x_ub, 0, 1, 1, 0, 0)
else:
raise RuntimeError("UnSupportedShape")
tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x], outputs=[res])
return tik_instance
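
At a high level the kernel above is a fused abs-then-max reduction over the supported shapes; a rough NumPy analogue (the real op tiles the reduction across AI cores and keeps partial results, which this sketch ignores):

import numpy as np

def fused_abs_max_reference(x):
    return np.max(np.abs(x))

x = np.random.randn(8, 128, 128).astype(np.float32)
print(fused_abs_max_reference(x))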

View File

@ -35,7 +35,7 @@ cus_img2col_info = TBERegOp("CusImg2Col") \
def shape56_0(tik_instance, input_x, res, input_shape, shape_info):
"""input_shape == ((32, 4, 56, 56, 16), 'float16', (3, 3), (1, 1))"""
"""input_shape is ((32, 4, 56, 56, 16), 'float16', (3, 3), (1, 1))"""
stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
pad = [1, 1, 1, 1]
l1_h, l1_w, jump_stride, repeat_mode = 56, 56, 1, 1
@ -62,7 +62,7 @@ def shape56_0(tik_instance, input_x, res, input_shape, shape_info):
def shape56_1(tik_instance, input_x, res, input_shape, shape_info):
"""input_shape == ((32, 8, 56, 56, 16), 'float16', (3, 3), (2, 2))"""
"""input_shape is ((32, 8, 56, 56, 16), 'float16', (3, 3), (2, 2))"""
stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
pad = [1, 1, 1, 1]
l1_h, l1_w, jump_stride, repeat_mode = 56, 56, 1, 1
@ -90,7 +90,7 @@ def shape56_1(tik_instance, input_x, res, input_shape, shape_info):
def shape56_2(tik_instance, input_x, res, input_shape, shape_info):
"""input_shape == ((32, 4, 56, 56, 16), 'float16', (1, 1), (1, 1))"""
"""input_shape is ((32, 4, 56, 56, 16), 'float16', (1, 1), (1, 1))"""
stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
pad = [0, 0, 0, 0]
l1_h, l1_w, c1_index, jump_stride, repeat_mode = 56, 56, 0, 1, 1
@ -114,7 +114,7 @@ def shape56_2(tik_instance, input_x, res, input_shape, shape_info):
def shape56_3(tik_instance, input_x, res, input_shape, shape_info):
"""input_shape == ((32, 16, 56, 56, 16), 'float16', (1, 1), (1, 1))"""
"""input_shape is ((32, 16, 56, 56, 16), 'float16', (1, 1), (1, 1))"""
stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
pad = [0, 0, 0, 0]
l1_h, l1_w, c1_index, jump_stride, repeat_mode = 56, 56, 0, 1, 1
@ -142,7 +142,7 @@ def shape56_3(tik_instance, input_x, res, input_shape, shape_info):
def shape56_4(tik_instance, input_x, res, input_shape, shape_info):
"""input_shape == ((32, 16, 56, 56, 16), 'float16', (1, 1), (2, 2))"""
"""input_shape is ((32, 16, 56, 56, 16), 'float16', (1, 1), (2, 2))"""
stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
pad = [0, 0, 0, 0]
l1_h, l1_w, c1_index, jump_stride, repeat_mode = 56, 56, 0, 1, 1
@ -171,7 +171,7 @@ def shape56_4(tik_instance, input_x, res, input_shape, shape_info):
def shape28_0(tik_instance, input_x, res, input_shape, shape_info):
"""input_shape == ((32, 8, 28, 28, 16), 'float16', (3, 3), (1, 1))"""
"""input_shape is ((32, 8, 28, 28, 16), 'float16', (3, 3), (1, 1))"""
stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
pad = [1, 1, 1, 1]
l1_h, l1_w, jump_stride, repeat_mode = 28, 28, 1, 1
@ -199,7 +199,7 @@ def shape28_0(tik_instance, input_x, res, input_shape, shape_info):
def shape28_1(tik_instance, input_x, res, input_shape, shape_info):
"""input_shape == ((32, 16, 28, 28, 16), 'float16', (3, 3), (2, 2))"""
"""input_shape is ((32, 16, 28, 28, 16), 'float16', (3, 3), (2, 2))"""
stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
pad = [1, 1, 1, 1]
l1_h, l1_w, jump_stride, repeat_mode = 28, 28, 1, 1
@ -238,7 +238,7 @@ def shape28_1(tik_instance, input_x, res, input_shape, shape_info):
def shape28_2(tik_instance, input_x, res, input_shape, shape_info):
"""input_shape == ((32, 32, 28, 28, 16), 'float16', (1, 1), (2, 2))"""
"""input_shape is ((32, 32, 28, 28, 16), 'float16', (1, 1), (2, 2))"""
stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
pad = [0, 0, 0, 0]
l1_h, l1_w, jump_stride, repeat_mode = 28, 28, 1, 1
@ -286,7 +286,7 @@ def shape28_2(tik_instance, input_x, res, input_shape, shape_info):
def shape28_3(tik_instance, input_x, res, input_shape, shape_info):
"""input_shape == ((32, 8, 28, 28, 16), 'float16', (1, 1), (1, 1))"""
"""input_shape is ((32, 8, 28, 28, 16), 'float16', (1, 1), (1, 1))"""
stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
pad = [0, 0, 0, 0]
l1_h, l1_w, jump_stride, repeat_mode = 28, 28, 1, 1
@ -314,7 +314,7 @@ def shape28_3(tik_instance, input_x, res, input_shape, shape_info):
def shape28_4(tik_instance, input_x, res, input_shape, shape_info):
"""input_shape == ((32, 32, 28, 28, 16), 'float16', (1, 1), (1, 1))"""
"""input_shape is ((32, 32, 28, 28, 16), 'float16', (1, 1), (1, 1))"""
stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
pad = [0, 0, 0, 0]
l1_h, l1_w, jump_stride, repeat_mode = 28, 28, 1, 1
@ -342,7 +342,7 @@ def shape28_4(tik_instance, input_x, res, input_shape, shape_info):
def shape14_0(tik_instance, input_x, res, input_shape, shape_info):
"""input_shape == ((32, 16, 14, 14, 16), 'float16', (3, 3), (1, 1))"""
"""input_shape is ((32, 16, 14, 14, 16), 'float16', (3, 3), (1, 1))"""
stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
pad = [1, 1, 1, 1]
l1_h, l1_w, jump_stride, repeat_mode = 14, 14, 1, 1
@ -382,7 +382,7 @@ def shape14_0(tik_instance, input_x, res, input_shape, shape_info):
def shape14_1(tik_instance, input_x, res, input_shape, shape_info):
"""input_shape == ((32, 32, 14, 14, 16), 'float16', (3, 3), (2, 2))"""
"""input_shape is ((32, 32, 14, 14, 16), 'float16', (3, 3), (2, 2))"""
stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
pad = [1, 1, 1, 1]
l1_h, l1_w, jump_stride, repeat_mode = 14, 14, 1, 1
@ -419,7 +419,7 @@ def shape14_1(tik_instance, input_x, res, input_shape, shape_info):
def shape14_2(tik_instance, input_x, res, input_shape, shape_info):
"""input_shape == ((32, 64, 14, 14, 16), 'float16', (1, 1), (2, 2))"""
"""input_shape is ((32, 64, 14, 14, 16), 'float16', (1, 1), (2, 2))"""
stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
pad = [0, 0, 0, 0]
l1_h, l1_w, jump_stride, repeat_mode = 14, 14, 1, 1
@ -454,7 +454,7 @@ def shape14_2(tik_instance, input_x, res, input_shape, shape_info):
def shape14_3(tik_instance, input_x, res, input_shape, shape_info):
"""input_shape == ((32, 64, 14, 14, 16), 'float16', (1, 1), (1, 1))"""
"""input_shape is ((32, 64, 14, 14, 16), 'float16', (1, 1), (1, 1))"""
stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
pad = [0, 0, 0, 0]
l1_h, l1_w, jump_stride, repeat_mode = 14, 14, 1, 1
@ -476,7 +476,7 @@ def shape14_3(tik_instance, input_x, res, input_shape, shape_info):
with tik_instance.for_range(eeb1 * 16, (eeb1 + 1) * 16) as i:
rep = 13
c1_index = 0
fetch_filter_w, fetch_filter_h, left_top_w, left_top_h = 0, 0, 0, 0
fetch_filter_w, fetch_filter_h, left_top_w, left_top_h, rep, c1_index = 0, 0, 0, 0, 13, 0
tik_instance.load3dv1(input_1_1_fractal_l1_local_ub[3328 * (i - eeb1 * 16)],
input_1_1_local_l1[3136 * i],
pad, l1_h, l1_w, c1_index, fetch_filter_w, fetch_filter_h,
@ -492,9 +492,7 @@ def shape14_3(tik_instance, input_x, res, input_shape, shape_info):
with tik_instance.for_range(0, 2) as eeb1:
with tik_instance.for_range(eeb1 * 16, (eeb1 + 1) * 16) as i:
rep = 13
fetch_filter_w, fetch_filter_h, left_top_w, left_top_h = 0, 0, 0, 0
c1_index = 0
fetch_filter_w, fetch_filter_h, left_top_w, left_top_h, rep, c1_index = 0, 0, 0, 0, 13, 0
tik_instance.load3dv1(input_1_1_fractal_l1_local_ub[3328 * (i - eeb1 * 16)],
input_1_2_local_l1[3136 * i],
pad, l1_h, l1_w, c1_index, fetch_filter_w, fetch_filter_h,
@ -511,7 +509,7 @@ def shape14_3(tik_instance, input_x, res, input_shape, shape_info):
def shape14_4(tik_instance, input_x, res, input_shape, shape_info):
"""input_shape == ((32, 16, 14, 14, 16), 'float16', (1, 1), (1, 1))"""
"""input_shape is ((32, 16, 14, 14, 16), 'float16', (1, 1), (1, 1))"""
stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
pad = [0, 0, 0, 0]
l1_h = 14
@ -551,7 +549,7 @@ def shape14_4(tik_instance, input_x, res, input_shape, shape_info):
def shape7_0(tik_instance, input_x, res, input_shape, shape_info):
"""input_shape == ((32, 32, 7, 7, 16), 'float16', (3, 3), (1, 1))"""
"""input_shape is ((32, 32, 7, 7, 16), 'float16', (3, 3), (1, 1))"""
stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
pad = [1, 1, 1, 1]
l1_h = 7
@ -576,36 +574,34 @@ def shape7_0(tik_instance, input_x, res, input_shape, shape_info):
left_top_h = -1
c1_index = 0
with tik_instance.for_range(0, 32) as i:
tik_instance.load3dv1(input_1_1_fractal_l1_local_ub[1024 * i], input_1_1_local_l1[784 * i],
pad, l1_h, l1_w, c1_index, fetch_filter_w, fetch_filter_h,
left_top_w, left_top_h, stride_w, stride_h, filter_w,
filter_h, dilation_filter_w, dilation_filter_h,
jump_stride, repeat_mode, rep)
tik_instance.load3dv1(input_1_1_fractal_l1_local_ub[1024 * i],
input_1_1_local_l1[784 * i], pad, l1_h, l1_w, c1_index,
fetch_filter_w, fetch_filter_h, left_top_w, left_top_h,
stride_w, stride_h, filter_w, filter_h, dilation_filter_w,
dilation_filter_h, jump_stride, repeat_mode, rep)
with tik_instance.for_range(0, 32) as i:
tik_instance.data_move(input_1_2_fractal_l1_local_ub[i * 49 * 16],
input_1_1_fractal_l1_local_ub[i * 1024], 0, 1, 49, 0, 0)
with tik_instance.for_range(0, 98) as i:
tik_instance.data_move(res[eeb + block_index * 9, i, 0, 0], input_1_2_fractal_l1_local_ub[256 * i],
0, 1, 16, 0, 0)
tik_instance.data_move(res[eeb + block_index * 9, i, 0, 0],
input_1_2_fractal_l1_local_ub[256 * i], 0, 1, 16, 0, 0)
return tik_instance, res
def shape7_1(tik_instance, input_x, res, input_shape, shape_info):
"""input_shape == ((32, 128, 7, 7, 16), 'float16', (1, 1), (1, 1))"""
"""input_shape is ((32, 128, 7, 7, 16), 'float16', (1, 1), (1, 1))"""
stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
pad = [0, 0, 0, 0]
l1_h = 7
l1_w = 7
jump_stride = 1
repeat_mode = 1
l1_h, l1_w, jump_stride, repeat_mode = 7, 7, 1, 1
with tik_instance.for_range(0, 32, block_num=32) as block_index:
input_1_2_fractal_l1_local_ub = tik_instance.Tensor("float16", (25088,), scope=tik.scope_ubuf,
name="input_1_2_fractal_l1_local_ub")
input_1_1_local_l1 = tik_instance.Tensor("float16", (25088,), scope=tik.scope_cbuf,
name="input_1_1_local_l1")
input_1_1_fractal_l1_local_ub = tik_instance.Tensor("float16", (32768,), scope=tik.scope_ubuf,
name="input_1_1_fractal_l1_local_ub")
input_1_2_fractal_l1_local_ub = tik_instance.Tensor("float16", (25088,), scope=tik.scope_ubuf,
name="input_1_2_fractal_l1_local_ub")
with tik_instance.for_range(0, 4) as eeb0, tik_instance.for_range(0, 32) as i:
tik_instance.data_move(input_1_1_local_l1[i * 784], input_x[i, eeb0 + block_index * 4, 0, 0, 0], 0,
1, 49, 0, 0)
@ -633,21 +629,21 @@ def shape7_1(tik_instance, input_x, res, input_shape, shape_info):
def shape7_2(tik_instance, input_x, res, input_shape, shape_info):
"""input_shape == ((32, 32, 7, 7, 16), 'float16', (1, 1), (1, 1))"""
"""input_shape is ((32, 32, 7, 7, 16), 'float16', (1, 1), (1, 1))"""
stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
pad = [0, 0, 0, 0]
l1_h = 7
l1_w = 7
c1_index = 0
jump_stride = 1
repeat_mode = 1
jump_stride, repeat_mode = 1, 1
with tik_instance.for_range(0, 32, block_num=32) as block_index:
input_1_2_fractal_l1_local_ub = tik_instance.Tensor("float16", (25088,), scope=tik.scope_ubuf,
name="input_1_2_fractal_l1_local_ub")
input_1_1_local_l1 = tik_instance.Tensor("float16", (25088,), scope=tik.scope_cbuf,
name="input_1_1_local_l1")
input_1_1_fractal_l1_local_ub = tik_instance.Tensor("float16", (32768,), scope=tik.scope_ubuf,
name="input_1_1_fractal_l1_local_ub")
input_1_2_fractal_l1_local_ub = tik_instance.Tensor("float16", (25088,), scope=tik.scope_ubuf,
name="input_1_2_fractal_l1_local_ub")
with tik_instance.for_range(0, 32) as i:
tik_instance.data_move(input_1_1_local_l1[i * 784], input_x[i, block_index, 0, 0, 0], 0, 1, 49, 0, 0)
@ -673,14 +669,10 @@ def shape7_2(tik_instance, input_x, res, input_shape, shape_info):
def height224_width224(tik_instance, input_x, res, input_shape, shape_info):
"""input_shape = ((32, 1, 224, 224, 16), 'float16', (7, 7), (2, 2))"""
"""input_shape is ((32, 1, 224, 224, 16), 'float16', (7, 7), (2, 2))"""
stride_w, stride_h, filter_w, filter_h, dilation_filter_w, dilation_filter_h = shape_info
pad = [3, 3, 3, 3]
l1_h = 56
l1_w = 224
c1_index = 0
jump_stride = 1
repeat_mode = 1
l1_h, l1_w, c1_index, jump_stride, repeat_mode = 56, 224, 0, 1, 1
with tik_instance.for_range(0, 32, block_num=32) as block_index:
input_1_1_local_l1 = tik_instance.Tensor("float16", (200704,), scope=tik.scope_cbuf,
name="input_1_1_local_l1")
@ -717,11 +709,10 @@ def height224_width224(tik_instance, input_x, res, input_shape, shape_info):
left_top_h = 1 + ((55 - temp - (-3 + eeb)) // 2 - 29) * 2
tik_instance.load3dv1(input_1_1_fractal_l1_local_ub, input_1_1_local_l1, pad,
l1_h, l1_w, c1_index, fetch_filter_w, fetch_filter_h,
left_top_w, left_top_h, stride_w, stride_h, filter_w,
filter_h, dilation_filter_w, dilation_filter_h, jump_stride,
repeat_mode, rep)
tik_instance.load3dv1(input_1_1_fractal_l1_local_ub, input_1_1_local_l1, pad, l1_h, l1_w,
c1_index, fetch_filter_w, fetch_filter_h, left_top_w, left_top_h,
stride_w, stride_h, filter_w, filter_h, dilation_filter_w,
dilation_filter_h, jump_stride, repeat_mode, rep)
with tik_instance.for_range(0, rep) as cc1:
tik_instance.data_move(
res[cc0 + eeb * 7, cc1 + rep_prefix + (eeb0 - 1) * rep + 784 * block_index, 0, 0],
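
For readers unfamiliar with load3dv1, a plain NCHW im2col sketch of the underlying transform; the kernel itself works on the NC1HWC0 fractal layout and tiles the copies per core, which this reference ignores:

import numpy as np

def im2col_reference(x, filter_h, filter_w, stride_h, stride_w, pad):
    n, c, h, w = x.shape
    x = np.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)))
    out_h = (h + 2 * pad - filter_h) // stride_h + 1
    out_w = (w + 2 * pad - filter_w) // stride_w + 1
    cols = np.zeros((n, c * filter_h * filter_w, out_h * out_w), dtype=x.dtype)
    for i in range(out_h):
        for j in range(out_w):
            patch = x[:, :, i * stride_h:i * stride_h + filter_h,
                      j * stride_w:j * stride_w + filter_w]
            cols[:, :, i * out_w + j] = patch.reshape(n, -1)
    return cols

x = np.random.randn(1, 4, 56, 56).astype(np.float16)
print(im2col_reference(x, 3, 3, 1, 1, pad=1).shape)  # (1, 36, 3136)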

View File

@ -46,6 +46,7 @@ matmul_cube_dense_left_op_info = TBERegOp("CusMatMulCubeDenseLeft") \
.dtype_format(DataType.F16_Default, DataType.F16_FracNZ, DataType.F16_Default, DataType.F16_FracNZ) \
.get_op_info()
def shape_gen1(input_x1, input_x2, output_y, kernel_name, trans_a, trans_b):
"""shape gen1"""
shape_a = input_x1.get("ori_shape")
@ -94,6 +95,7 @@ def shape_gen1(input_x1, input_x2, output_y, kernel_name, trans_a, trans_b):
trans_b = bool(1 - trans_b)
return shape_a, shape_b, trans_a, trans_b, shape_output
def shape_gen2(bias, input_x1, output_y, shape_a, shape_b, trans_a, trans_b):
"""shape gen2"""
shape_bias = ()
@ -144,6 +146,7 @@ def shape_gen2(bias, input_x1, output_y, shape_a, shape_b, trans_a, trans_b):
format_b = "FRACTAL_NZ"
return shape_a_temp, format_a, shape_b_temp, format_b, shape_bias, src_dtype, dst_dtype
def core(shape_a_temp, shape_b_temp, shape_output, kernel_name):
"""core func"""
if util.get_product_version() == util.VERSION_MINI:
@ -206,6 +209,7 @@ def core(shape_a_temp, shape_b_temp, shape_output, kernel_name):
tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x1, input_x2], outputs=[resmatmul])
return tik_instance
@op_info_register(matmul_cube_dense_left_op_info)
def cus_matmul_cube_dense_left(input_x1, input_x2, bias=None, output_y=None, trans_a=False, trans_b=False,
kernel_name="cus_matmul_cube_dense_left"):

View File

@ -68,106 +68,133 @@ def cus_matmul_cube_dense_right(input_x1, input_x2, input_x3, output_y=None,
core_m_idx = block_index // 16
core_n_idx = block_index % 16
matrix_max_scalar = tik_instance.Scalar("float32")
matrix_max_local_UB = tik_instance.Tensor("float32", (8,), scope=tik.scope_ubuf,
name="matrix_max_local_UB")
tik_instance.data_move(matrix_max_local_UB, input_x3, 0, 1, 1, 0, 0)
matrix_max_scalar.set_as(matrix_max_local_UB[0])
resmatmul_local_ub = tik_instance.Tensor("float32", (256 * 128,), scope=tik.scope_ubuf,
name="resmatmul_local_ub")
resmatmul_local_ub1 = tik_instance.Tensor("float32", (240 * 128,), scope=tik.scope_ubuf,
name="resmatmul_local_ub1")
resmatmul_local_ub_local_l0c = tik_instance.Tensor("float32", (256 * 128,), scope=tik.scope_cc,
name="resmatmul_local_ub_local_l0c")
resmatmul_local_ub_local_l0c1 = tik_instance.Tensor("float32", (240 * 128,), scope=tik.scope_cc,
name="resmatmul_local_ub_local_l0c1")
input_1_local_l1_local_L0A = tik_instance.Tensor("float16", (256 * 128,), scope=tik.scope_ca,
name="input_1_local_l1_local_L0A")
input_2_local_l1 = tik_instance.Tensor("float16", (8 * 128 * 16,), scope=tik.scope_cbuf,
name="input_2_local_l1")
input_2_local_l11 = tik_instance.Tensor("float16", (8 * 128 * 16,), scope=tik.scope_cbuf,
name="input_2_local_l11")
input_1_local_l1 = tik_instance.Tensor("float16", (8 * 256 * 16,), scope=tik.scope_cbuf,
name="input_1_local_l1")
input_1_local_l11 = tik_instance.Tensor("float16", (8 * 240 * 16,), scope=tik.scope_cbuf,
name="input_1_local_l11")
input_2_local_l1_local_l0b = tik_instance.Tensor("float16", (128 * 128,), scope=tik.scope_cb,
name="input_2_local_l1_local_l0b")
input_2_local_l1_local_l0b1 = tik_instance.Tensor("float16", (128 * 128,), scope=tik.scope_cb,
name="input_2_local_l1_local_l0b1")
matrix_max_local_ub = tik_instance.Tensor("float32", (8,), scope=tik.scope_ubuf,
name="matrix_max_local_ub")
tik_instance.data_move(matrix_max_local_ub, input_x3, 0, 1, 1, 0, 0)
matrix_max_scalar.set_as(matrix_max_local_ub[0])
with tik_instance.if_scope(core_m_idx == 0):
with tik_instance.for_range(0, 2) as cc1:
tik_instance.data_move(input_2_local_l1, input_x2[core_n_idx * 262144 + core_n_idx * 2048], 0, 8,
128, 1920, 0)
tik_instance.data_move(input_1_local_l1, input_x1[core_n_idx * 129024 + cc1 * 4096], 0, 8, 256,
752, 0)
with tik_instance.for_range(0, 8) as cc10:
tik_instance.load2dv1(input_2_local_l1_local_l0b[cc10 * 2048], input_2_local_l1[cc10 * 256], 0,
8, 8, 0, True)
with tik_instance.for_range(0, 16) as cc101:
tik_instance.load2dv1(input_1_local_l1_local_L0A[cc101 * 2048], input_1_local_l1[cc101 * 256],
0, 8, 16, 0, False)
tik_instance.mmad(resmatmul_local_ub_local_l0c, input_1_local_l1_local_L0A,
input_2_local_l1_local_l0b, 256, 128, 128, 0)
tik_instance.data_move(resmatmul_local_ub, resmatmul_local_ub_local_l0c, 0, 1, 128, 0, 0)
tik_instance.vmuls(64, resmatmul_local_ub, resmatmul_local_ub, matrix_max_scalar, 255, 1, 1, 8, 8)
tik_instance.vmuls(64, resmatmul_local_ub[255 * 64], resmatmul_local_ub[255 * 64],
matrix_max_scalar, 255, 1, 1, 8, 8)
tik_instance.vmuls(64, resmatmul_local_ub[510 * 64], resmatmul_local_ub[510 * 64],
matrix_max_scalar, 2, 1, 1, 8, 8)
tik_instance.data_move(resmatmul[core_n_idx * 129024 + cc1 * 4096], resmatmul_local_ub, 0, 8, 512,
0, 1504)
tik_instance, resmatmul = _update_tik1(tik_instance, input_x1, input_x2, core_n_idx,
resmatmul, matrix_max_scalar)
with tik_instance.else_scope():
tik_instance.data_move(input_2_local_l1, input_x2[core_n_idx * 262144 + core_n_idx * 2048], 0, 8, 128,
1920, 0)
tik_instance.data_move(input_1_local_l1, input_x1[core_n_idx * 129024 + 2 * 4096], 0, 8, 256, 752, 0)
with tik_instance.for_range(0, 8) as cc10:
tik_instance.load2dv1(input_2_local_l1_local_l0b[cc10 * 2048], input_2_local_l1[cc10 * 256], 0, 8,
8, 0, True)
with tik_instance.for_range(0, 16) as cc101:
tik_instance.load2dv1(input_1_local_l1_local_L0A[cc101 * 2048], input_1_local_l1[cc101 * 256],
0, 8, 16, 0, False)
tik_instance.mmad(resmatmul_local_ub_local_l0c, input_1_local_l1_local_L0A, input_2_local_l1_local_l0b,
256, 128, 128, 0)
tik_instance.data_move(resmatmul_local_ub, resmatmul_local_ub_local_l0c, 0, 1, 128, 0, 0)
tik_instance.vmuls(64, resmatmul_local_ub, resmatmul_local_ub, matrix_max_scalar, 255, 1, 1, 8, 8)
tik_instance.vmuls(64, resmatmul_local_ub[255 * 64], resmatmul_local_ub[255 * 64], matrix_max_scalar,
255, 1, 1, 8, 8)
tik_instance.vmuls(64, resmatmul_local_ub[510 * 64], resmatmul_local_ub[510 * 64], matrix_max_scalar,
2, 1, 1, 8, 8)
tik_instance.data_move(resmatmul[core_n_idx * 129024 + 2 * 4096], resmatmul_local_ub, 0, 8, 512, 0,
1504)
tik_instance.data_move(input_2_local_l11, input_x2[core_n_idx * 262144 + core_n_idx * 2048], 0, 8, 128,
1920, 0)
tik_instance.data_move(input_1_local_l11, input_x1[core_n_idx * 129024 + 12288], 0, 8, 240, 768, 0)
with tik_instance.for_range(0, 8) as cc102:
tik_instance.load2dv1(input_2_local_l1_local_l0b1[cc102 * 2048], input_2_local_l11[cc102 * 256], 0,
8, 8, 0, True)
with tik_instance.for_range(0, 16) as cc103:
tik_instance.load2dv1(input_1_local_l1_local_L0A[cc103 * 2048], input_1_local_l11[cc103 * 256], 0,
8, 15, 0, False)
tik_instance.mmad(resmatmul_local_ub_local_l0c1, input_1_local_l1_local_L0A,
input_2_local_l1_local_l0b1, 240, 128, 128, 0)
tik_instance.data_move(resmatmul_local_ub1, resmatmul_local_ub_local_l0c1, 0, 1, 120, 0, 0)
tik_instance.vmuls(64, resmatmul_local_ub1, resmatmul_local_ub1, matrix_max_scalar, 255, 1, 1, 8, 8)
tik_instance.vmuls(64, resmatmul_local_ub1[255 * 64], resmatmul_local_ub1[255 * 64], matrix_max_scalar,
225, 1, 1, 8, 8)
tik_instance.data_move(resmatmul[core_n_idx * 129024 + 12288], resmatmul_local_ub1, 0, 8, 480, 0, 1536)
tik_instance, resmatmul = _update_tik2(tik_instance, input_x1, input_x2, core_n_idx,
resmatmul, matrix_max_scalar)
tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x1, input_x2, input_x3], outputs=[resmatmul])
return tik_instance
return None
def _update_tik1(tik_instance, input_x1, input_x2, core_n_idx, resmatmul, matrix_max_scalar):
"""_update_tik1"""
resmatmul_local_ub = tik_instance.Tensor("float32", (256 * 128,), scope=tik.scope_ubuf,
name="resmatmul_local_ub")
resmatmul_local_ub_local_l0c = tik_instance.Tensor("float32", (256 * 128,), scope=tik.scope_cc,
name="resmatmul_local_ub_local_l0c")
input_1_local_l1_local_l0a = tik_instance.Tensor("float16", (256 * 128,), scope=tik.scope_ca,
name="input_1_local_l1_local_l0a")
input_2_local_l1 = tik_instance.Tensor("float16", (8 * 128 * 16,), scope=tik.scope_cbuf,
name="input_2_local_l1")
input_1_local_l1 = tik_instance.Tensor("float16", (8 * 256 * 16,), scope=tik.scope_cbuf,
name="input_1_local_l1")
input_2_local_l1_local_l0b = tik_instance.Tensor("float16", (128 * 128,), scope=tik.scope_cb,
name="input_2_local_l1_local_l0b")
with tik_instance.for_range(0, 2) as cc1:
tik_instance.data_move(input_2_local_l1, input_x2[core_n_idx * 262144 + core_n_idx * 2048], 0, 8,
128, 1920, 0)
tik_instance.data_move(input_1_local_l1, input_x1[core_n_idx * 129024 + cc1 * 4096], 0, 8, 256,
752, 0)
with tik_instance.for_range(0, 8) as cc10:
tik_instance.load2dv1(input_2_local_l1_local_l0b[cc10 * 2048], input_2_local_l1[cc10 * 256], 0,
8, 8, 0, True)
with tik_instance.for_range(0, 16) as cc101:
tik_instance.load2dv1(input_1_local_l1_local_l0a[cc101 * 2048], input_1_local_l1[cc101 * 256],
0, 8, 16, 0, False)
tik_instance.mmad(resmatmul_local_ub_local_l0c, input_1_local_l1_local_l0a,
input_2_local_l1_local_l0b, 256, 128, 128, 0)
tik_instance.data_move(resmatmul_local_ub, resmatmul_local_ub_local_l0c, 0, 1, 128, 0, 0)
tik_instance.vmuls(64, resmatmul_local_ub, resmatmul_local_ub, matrix_max_scalar, 255, 1, 1, 8, 8)
tik_instance.vmuls(64, resmatmul_local_ub[255 * 64], resmatmul_local_ub[255 * 64],
matrix_max_scalar, 255, 1, 1, 8, 8)
tik_instance.vmuls(64, resmatmul_local_ub[510 * 64], resmatmul_local_ub[510 * 64],
matrix_max_scalar, 2, 1, 1, 8, 8)
tik_instance.data_move(resmatmul[core_n_idx * 129024 + cc1 * 4096], resmatmul_local_ub, 0, 8, 512,
0, 1504)
return tik_instance, resmatmul
def _update_tik2(tik_instance, input_x1, input_x2, core_n_idx, resmatmul, matrix_max_scalar):
"""_update_tik2"""
resmatmul_local_ub = tik_instance.Tensor("float32", (256 * 128,), scope=tik.scope_ubuf,
name="resmatmul_local_ub")
resmatmul_local_ub1 = tik_instance.Tensor("float32", (240 * 128,), scope=tik.scope_ubuf,
name="resmatmul_local_ub1")
resmatmul_local_ub_local_l0c = tik_instance.Tensor("float32", (256 * 128,), scope=tik.scope_cc,
name="resmatmul_local_ub_local_l0c")
resmatmul_local_ub_local_l0c1 = tik_instance.Tensor("float32", (240 * 128,), scope=tik.scope_cc,
name="resmatmul_local_ub_local_l0c1")
input_1_local_l1_local_l0a = tik_instance.Tensor("float16", (256 * 128,), scope=tik.scope_ca,
name="input_1_local_l1_local_l0a")
input_2_local_l1 = tik_instance.Tensor("float16", (8 * 128 * 16,), scope=tik.scope_cbuf,
name="input_2_local_l1")
input_2_local_l11 = tik_instance.Tensor("float16", (8 * 128 * 16,), scope=tik.scope_cbuf,
name="input_2_local_l11")
input_1_local_l1 = tik_instance.Tensor("float16", (8 * 256 * 16,), scope=tik.scope_cbuf,
name="input_1_local_l1")
input_1_local_l11 = tik_instance.Tensor("float16", (8 * 240 * 16,), scope=tik.scope_cbuf,
name="input_1_local_l11")
input_2_local_l1_local_l0b = tik_instance.Tensor("float16", (128 * 128,), scope=tik.scope_cb,
name="input_2_local_l1_local_l0b")
input_2_local_l1_local_l0b1 = tik_instance.Tensor("float16", (128 * 128,), scope=tik.scope_cb,
name="input_2_local_l1_local_l0b1")
tik_instance.data_move(input_2_local_l1, input_x2[core_n_idx * 262144 + core_n_idx * 2048], 0, 8, 128,
1920, 0)
tik_instance.data_move(input_1_local_l1, input_x1[core_n_idx * 129024 + 2 * 4096], 0, 8, 256, 752, 0)
with tik_instance.for_range(0, 8) as cc10:
tik_instance.load2dv1(input_2_local_l1_local_l0b[cc10 * 2048], input_2_local_l1[cc10 * 256], 0, 8,
8, 0, True)
with tik_instance.for_range(0, 16) as cc101:
tik_instance.load2dv1(input_1_local_l1_local_l0a[cc101 * 2048], input_1_local_l1[cc101 * 256],
0, 8, 16, 0, False)
tik_instance.mmad(resmatmul_local_ub_local_l0c, input_1_local_l1_local_l0a, input_2_local_l1_local_l0b,
256, 128, 128, 0)
tik_instance.data_move(resmatmul_local_ub, resmatmul_local_ub_local_l0c, 0, 1, 128, 0, 0)
tik_instance.vmuls(64, resmatmul_local_ub, resmatmul_local_ub, matrix_max_scalar, 255, 1, 1, 8, 8)
tik_instance.vmuls(64, resmatmul_local_ub[255 * 64], resmatmul_local_ub[255 * 64], matrix_max_scalar,
255, 1, 1, 8, 8)
tik_instance.vmuls(64, resmatmul_local_ub[510 * 64], resmatmul_local_ub[510 * 64], matrix_max_scalar,
2, 1, 1, 8, 8)
tik_instance.data_move(resmatmul[core_n_idx * 129024 + 2 * 4096], resmatmul_local_ub, 0, 8, 512, 0,
1504)
tik_instance.data_move(input_2_local_l11, input_x2[core_n_idx * 262144 + core_n_idx * 2048], 0, 8, 128,
1920, 0)
tik_instance.data_move(input_1_local_l11, input_x1[core_n_idx * 129024 + 12288], 0, 8, 240, 768, 0)
with tik_instance.for_range(0, 8) as cc102:
tik_instance.load2dv1(input_2_local_l1_local_l0b1[cc102 * 2048], input_2_local_l11[cc102 * 256], 0,
8, 8, 0, True)
with tik_instance.for_range(0, 16) as cc103:
tik_instance.load2dv1(input_1_local_l1_local_l0a[cc103 * 2048], input_1_local_l11[cc103 * 256], 0,
8, 15, 0, False)
tik_instance.mmad(resmatmul_local_ub_local_l0c1, input_1_local_l1_local_l0a,
input_2_local_l1_local_l0b1, 240, 128, 128, 0)
tik_instance.data_move(resmatmul_local_ub1, resmatmul_local_ub_local_l0c1, 0, 1, 120, 0, 0)
tik_instance.vmuls(64, resmatmul_local_ub1, resmatmul_local_ub1, matrix_max_scalar, 255, 1, 1, 8, 8)
tik_instance.vmuls(64, resmatmul_local_ub1[255 * 64], resmatmul_local_ub1[255 * 64], matrix_max_scalar,
225, 1, 1, 8, 8)
tik_instance.data_move(resmatmul[core_n_idx * 129024 + 12288], resmatmul_local_ub1, 0, 8, 480, 0, 1536)
return tik_instance, resmatmul
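
Stripped of the tiling, the kernel in this file computes a matmul whose result is scaled by the single float stored in input_x3; a rough NumPy summary (the fractal layout and any implicit operand transposes are ignored):

import numpy as np

def matmul_scaled_reference(x1, x2, matrix_max):
    scale = np.float32(matrix_max[0])      # matrix_max_scalar in the kernel
    return (x1.astype(np.float32) @ x2.astype(np.float32)) * scale

x1 = np.random.rand(128, 128).astype(np.float16)
x2 = np.random.rand(128, 128).astype(np.float16)
print(matmul_scaled_reference(x1, x2, np.array([0.5], dtype=np.float32)).shape)  # (128, 128)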

View File

@ -44,6 +44,31 @@ matmul_cube_fracz_left_cast_op_info = TBERegOp("CusMatMulCubeFraczLeftCast") \
.get_op_info()
def _clip_num(num):
"""clip number"""
if num == 0:
num = 1
return num
def _get_block(trans_a, trans_b, m_shape, n_shape, km_shape, kn_shape):
"""_get_block"""
block_in0 = cce.BLOCK_IN
block_out0 = cce.BLOCK_OUT
if trans_a and km_shape == 1:
block_in0 = cce.BLOCK_VECTOR
if not trans_a and m_shape == 1:
block_in0 = cce.BLOCK_VECTOR
if trans_b and kn_shape == 1:
block_out0 = cce.BLOCK_VECTOR
if not trans_b and n_shape == 1:
block_out0 = cce.BLOCK_VECTOR
return block_in0, block_out0
@op_info_register(matmul_cube_fracz_left_cast_op_info)
def cus_matmul_cube_fraczleftcast(input_x1, input_x2, bias=None, output_y=None, trans_a=False, trans_b=False,
kernel_name="cus_matmul_cube_fraczleftcast"):
@ -52,132 +77,113 @@ def cus_matmul_cube_fraczleftcast(input_x1, input_x2, bias=None, output_y=None,
data with fractal format.
Parameters:
shape_a: list or tuple
shape_aa: list or tuple
Shape of the first tensor a with rank > 1
shape_b: list or tuple
shape_bb: list or tuple
Shape of the second tensor b with the same type with a,
and shape_a, shape_b must be 2 dims
and shape_aa, shape_bb must be 2 dims
src_dtype: str
The data type of input, support "float32", "float16"
dst_dtype: str
The data type of output, support "float32", "float16"
trans_a: bool
If True, shape_a == transposed before multiplication
If True, shape_aa == transposed before multiplication
trans_b: bool
If True, shape_b == transposed before multiplication
If True, shape_bb == transposed before multiplication
is_fractal: bool
If True, the input data format of a and b must be fractal format
shape_bias: list or tuple
shape_bbias: list or tuple
Shape of bias, only support the input data format with ND
Returns
-------
None
"""
shape_a = input_x1.get("ori_shape")
shape_b = input_x2.get("ori_shape")
print("============")
print(input_x1.get("format"), input_x2.get("format"))
print(shape_a, shape_b)
print("============")
shape_aa = input_x1.get("ori_shape")
shape_bb = input_x2.get("ori_shape")
if input_x2.get("format") == "FRACTAL_Z":
n, c, h, w = shape_b
n, c, h, w = shape_bb
c0 = 16
c1 = c // c0
if c1 == 0:
c1 = 1
shape_b = [n, c1 * h * w * c0]
shape_a = [n, n]
c1 = _clip_num(c1)
shape_bb = [n, c1 * h * w * c0]
shape_aa = [n, n]
if input_x1.get("format") == "FRACTAL_Z":
n, c, h, w = shape_a
n, c, h, w = shape_aa
c0 = 16
c1 = c // c0
if c1 == 0:
c1 = 1
shape_a = [n, c1 * h * w * c0]
shape_b = [c1 * h * w * c0, c1 * h * w * c0]
c1 = _clip_num(c1)
shape_aa = [n, c1 * h * w * c0]
shape_bb = [c1 * h * w * c0, c1 * h * w * c0]
if input_x2.get("format") == "FRACTAL_NZ":
shape_a = [shape_b[0], shape_b[0]]
shape_aa = [shape_bb[0], shape_bb[0]]
if input_x1.get("format") == "FRACTAL_NZ":
shape_b = [shape_a[1], shape_a[1]]
shape_bb = [shape_aa[1], shape_aa[1]]
shape_a = list(shape_a)
shape_b = list(shape_b)
shape_aa = list(shape_aa)
shape_bb = list(shape_bb)
shape_a = _get_input_shape(shape_a)
shape_b = _get_input_shape(shape_b)
shape_aa = _get_input_shape(shape_aa)
shape_bb = _get_input_shape(shape_bb)
util.check_kernel_name(kernel_name)
util.check_shape_rule(shape_a)
util.check_shape_rule(shape_b)
util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT)
util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT)
util.check_shape_rule(shape_aa)
util.check_shape_rule(shape_bb)
util.check_shape_size(shape_aa, SHAPE_SIZE_LIMIT)
util.check_shape_size(shape_bb, SHAPE_SIZE_LIMIT)
shape_a = [shape_a[1], shape_a[0]]
shape_aa = [shape_aa[1], shape_aa[0]]
trans_a = bool(1 - trans_a)
shape_b = [shape_b[1], shape_b[0]]
shape_bb = [shape_bb[1], shape_bb[0]]
trans_b = bool(1 - trans_b)
shape_bias = ()
shape_bbias = ()
if bias is not None and bool(bias):
shape_bias = bias.get("shape")
shape_bias = list(shape_bias)
shape_bias = _get_bias(shape_bias)
shape_bbias = bias.get("shape")
shape_bbias = list(shape_bbias)
shape_bbias = _get_bias(shape_bbias)
src_dtype = input_x1.get("dtype").lower()
_shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b)
_shape_check(shape_aa, shape_bb, shape_bbias, src_dtype, trans_a, trans_b)
m_shape = shape_a[len(shape_a) - 2]
km_shape = shape_a[len(shape_a) - 1]
kn_shape = shape_b[len(shape_a) - 2]
n_shape = shape_b[len(shape_a) - 1]
m_shape = shape_aa[len(shape_aa) - 2]
km_shape = shape_aa[len(shape_aa) - 1]
kn_shape = shape_bb[len(shape_aa) - 2]
n_shape = shape_bb[len(shape_aa) - 1]
if src_dtype == "float16":
block_reduce = cce.BLOCK_REDUCE
block_in = cce.BLOCK_IN
block_out = cce.BLOCK_OUT
if trans_a and km_shape == 1:
block_in = cce.BLOCK_VECTOR
if not trans_a and m_shape == 1:
block_in = cce.BLOCK_VECTOR
if trans_b and kn_shape == 1:
block_out = cce.BLOCK_VECTOR
if not trans_b and n_shape == 1:
block_out = cce.BLOCK_VECTOR
block_in0, block_out0 = _get_block(trans_a, trans_b, m_shape, n_shape, km_shape, kn_shape)
if trans_a:
shape_a_temp = (m_shape // block_reduce, km_shape // block_in, block_reduce, block_in)
shape_aa_tmp = (m_shape // block_reduce, km_shape // block_in0, block_reduce, block_in0)
else:
shape_a_temp = (m_shape // block_in, km_shape // block_reduce, block_in, block_reduce)
shape_aa_tmp = (m_shape // block_in0, km_shape // block_reduce, block_in0, block_reduce)
if trans_b:
shape_b_temp = (kn_shape // block_out, n_shape // block_reduce, block_reduce, block_out)
shape_bb_tmp = (kn_shape // block_out0, n_shape // block_reduce, block_reduce, block_out0)
else:
shape_b_temp = (kn_shape // block_reduce, n_shape // block_out, block_out, block_reduce)
shape_a_temp = (shape_a_temp[0], shape_a_temp[1], shape_a_temp[2], shape_a_temp[3])
shape_b_temp = (shape_b_temp[0], shape_b_temp[1], shape_b_temp[2], shape_b_temp[3])
shape_bb_tmp = (kn_shape // block_reduce, n_shape // block_out0, block_out0, block_reduce)
shape_aa_tmp = (shape_aa_tmp[0], shape_aa_tmp[1], shape_aa_tmp[2], shape_aa_tmp[3])
shape_bb_tmp = (shape_bb_tmp[0], shape_bb_tmp[1], shape_bb_tmp[2], shape_bb_tmp[3])
if util.get_product_version() == util.VERSION_MINI:
tik_instance = tik.Tik(tik.Dprofile("v100", "mini"))
else:
tik_instance = tik.Tik(tik.Dprofile("v100", "cloud"))
input_x1 = tik_instance.Tensor(input_x1.get("dtype"), shape_a_temp, name="left_matrix", scope=tik.scope_gm)
input_x2 = tik_instance.Tensor(input_x2.get("dtype"), shape_b_temp, name="right_matrix", scope=tik.scope_gm)
res_matmul = tik_instance.Tensor(output_y.get("dtype"), output_y.get("shape"), name="output", scope=tik.scope_gm)
mo_tile, ko_tile, no_tile, diag_opt = get_cus_tile_info(input_x1, input_x2, 128)
cus_cube_matmul_cast(tik_instance, input_x1, trans_a, input_x2, trans_b, res_matmul,
input_x1 = tik_instance.Tensor(input_x1.get("dtype"), shape_aa_tmp, name="left_matrix", scope=tik.scope_gm)
input_x2 = tik_instance.Tensor(input_x2.get("dtype"), shape_bb_tmp, name="right_matrix", scope=tik.scope_gm)
res_matmul0 = tik_instance.Tensor(output_y.get("dtype"), output_y.get("shape"), name="output", scope=tik.scope_gm)
mo_tile, ko_tile, no_tile, dig_opt = get_cus_tile_info(input_x1, input_x2, 128)
cus_cube_matmul_cast(tik_instance, input_x1, trans_a, input_x2, trans_b, res_matmul0,
mo_tile=mo_tile, ko_tile=ko_tile, no_tile=no_tile,
diag_opt=diag_opt, diag_size=128)
tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x1, input_x2], outputs=[res_matmul])
diag_opt=dig_opt, diag_size=128)
tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x1, input_x2], outputs=[res_matmul0])
return tik_instance

View File

@ -86,9 +86,9 @@ def cus_matmul_cube_fraczrightmul(input_x1, input_x2, input_x3, output_y=None, k
input_x1 = tik_instance.Tensor("float16", input_x1_shape, name="left_matrix", scope=tik.scope_gm)
input_x2 = tik_instance.Tensor("float16", input_x2_shape, name="right_matrix", scope=tik.scope_gm)
input_x3 = tik_instance.Tensor("float32", input_x3_shape, name="matrix_max", scope=tik.scope_gm)
resMatmul = tik_instance.Tensor("float32", output_shape, name="output", scope=tik.scope_gm)
cus_cube_matmul_right_mul(tik_instance, input_x1, input_x2, input_x3, resMatmul)
tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x1, input_x2, input_x3], outputs=[resMatmul])
resmatmul = tik_instance.Tensor("float32", output_shape, name="output", scope=tik.scope_gm)
cus_cube_matmul_right_mul(tik_instance, input_x1, input_x2, input_x3, resmatmul)
tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x1, input_x2, input_x3], outputs=[resmatmul])
return tik_instance
@ -177,7 +177,7 @@ def cus_cube_matmul_right_mul(tik_instance, input_x1, input_x2, input_x3,
core_m = block_idx // core_n_num
core_n = block_idx % core_n_num
res_l0c = tik_instance.Tensor("float32", [no_tile, mo_tile, c0, c0],
name="resMatmul_L0C", scope=tik.scope_cc)
name="resmatmul_L0C", scope=tik.scope_cc)
with tik_instance.for_range(0, loop_k_num, thread_num=thread_num_k) as thread_idx_k:
if diag_opt:
k_idx = (core_n * loop_n_num + cc_n) * no_tile + thread_idx_k * ko_tile_inner
@ -219,7 +219,7 @@ def cus_cube_matmul_right_mul(tik_instance, input_x1, input_x2, input_x3,
tik_instance.mmad(res_l0c, input_x1_l0a, input_x2_l0b, mo_tile * c0,
ko_tile_inner * c0, no_tile * c0, 1)
res_ub = tik_instance.Tensor("float32", [no_tile, mo_tile, c0, c0],
name="resMatmul_ub", scope=tik.scope_ubuf)
name="resmatmul_ub", scope=tik.scope_ubuf)
tik_instance.data_move(res_ub, res_l0c, 0, 1, no_tile * mo_tile, 0, 0)
input_3_local_ub = tik_instance.Tensor("float32", (8,), scope=tik.scope_ubuf, name="input_3_local_ub")

View File

@ -110,13 +110,13 @@ def cus_matmul_cube(input_x1, input_x2, bias=None, output_y=None, trans_a=False,
result = te.lang.cce.matmul(tensor_a, tensor_b, trans_a, trans_b, format_a=format_a,
format_b=format_b, dst_dtype=dst_dtype, tensor_bias=tensor_bias)
with tvm.target.cce():
schedule = generic.auto_schedule(result)
tensor_list = [tensor_a, tensor_b, result]
if shape_bias:
tensor_list = [tensor_a, tensor_b, tensor_bias, result]
with tvm.target.cce():
schedule = generic.auto_schedule(result)
config = {"print_ir": False,
"name": kernel_name,
"tensor_list": tensor_list}
@ -163,19 +163,21 @@ def _get_block(shape_a, shape_b, trans_a, trans_b):
km_shape = shape_a[len(shape_a) - 1]
kn_shape = shape_b[len(shape_a) - 2]
n_shape = shape_b[len(shape_a) - 1]
block_in = cce.BLOCK_IN
block_out = cce.BLOCK_OUT
if trans_a and km_shape == 1:
block_in = cce.BLOCK_VECTOR
if not trans_a and m_shape == 1:
block_in = cce.BLOCK_VECTOR
block_in = cce.BLOCK_IN
if trans_b and kn_shape == 1:
block_out = cce.BLOCK_VECTOR
if not trans_b and n_shape == 1:
block_out = cce.BLOCK_VECTOR
if trans_a and km_shape == 1:
block_in = cce.BLOCK_VECTOR
if not trans_a and m_shape == 1:
block_in = cce.BLOCK_VECTOR
return block_in, block_out
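
The reordered checks in _get_block keep the same rule: a matrix dimension of size 1 downgrades the corresponding block to a vector block. A small sketch, treating the cce constant values as assumptions:

BLOCK_IN, BLOCK_OUT, BLOCK_VECTOR = 16, 16, 1  # assumed values of the cce constants

def select_blocks(trans_a, trans_b, m_shape, n_shape, km_shape, kn_shape):
    block_in = BLOCK_VECTOR if (trans_a and km_shape == 1) or (not trans_a and m_shape == 1) else BLOCK_IN
    block_out = BLOCK_VECTOR if (trans_b and kn_shape == 1) or (not trans_b and n_shape == 1) else BLOCK_OUT
    return block_in, block_out

print(select_blocks(False, False, 1, 128, 128, 128))  # (1, 16): row vector times matrix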

View File

@ -82,7 +82,6 @@ def minmax_update_perchannel(x, min_val, max_val, min_up, max_up,
kernel_name="minmax_update_perchannel"):
"""MinMaxUpdatePerChannel op"""
x_shape = x.get("ori_shape")
x_format = x.get("format")
x_dtype = x.get("dtype")
min_shape = min_val.get("ori_shape")
min_dtype = min_val.get("dtype")

View File

@ -39,12 +39,8 @@ def _get_tik():
return tik_instance
@op_info_register(cus_transpose02314_op_info)
def cus_transpose02314(input_x, output, kernel_name="cus_transpose021354"):
"""CusTranspose02314"""
input_x_shape = input_x.get("shape")
output_shape = output.get("shape")
input_x_shape = tuple(input_x_shape)
def _error_feedback(input_x_shape):
"""error feedback"""
support_shape = [(32, 128, 7, 7, 16),
(32, 32, 7, 7, 16),
(32, 32, 14, 14, 16),
@ -60,6 +56,16 @@ def cus_transpose02314(input_x, output, kernel_name="cus_transpose021354"):
if input_x_shape not in support_shape:
raise RuntimeError("input_shape %s is not supported" % str(input_x_shape))
@op_info_register(cus_transpose02314_op_info)
def cus_transpose02314(input_x, output, kernel_name="cus_transpose021354"):
"""CusTranspose02314"""
input_x_shape = input_x.get("shape")
output_shape = output.get("shape")
input_x_shape = tuple(input_x_shape)
_error_feedback(input_x_shape)
tik_instance = _get_tik()
input_x = tik_instance.Tensor("float16", input_x_shape, name="input_x", scope=tik.scope_gm)
@ -302,6 +308,7 @@ def shape9(tik_instance, input_x, res, dtype):
def shape10(tik_instance, input_x, res, dtype):
"""input shape (32, 32, 14, 14, 16)"""
def _inner_compute(split_index):
input_x_ub = tik_instance.Tensor(dtype, [1, 32, 2, 14, 16], name="input_1_local_ub",
scope=tik.scope_ubuf)
@ -323,6 +330,7 @@ def shape10(tik_instance, input_x, res, dtype):
def shape11(tik_instance, input_x, res, dtype):
"""input shape (32, 64, 14, 14, 16)"""
def _inner_compute(split_index, block_idx):
input_x_ub = tik_instance.Tensor(dtype, [1, 64, 2, 14, 16], name="input_1_local_ub",
scope=tik.scope_ubuf)
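
The op name suggests the NC1HWC0 axes are permuted in the order 0, 2, 3, 1, 4 and then flattened; a hypothetical NumPy reference built on that reading (not taken from the kernel itself):

import numpy as np

def transpose02314_reference(x):
    n, c1, h, w, c0 = x.shape
    return np.transpose(x, (0, 2, 3, 1, 4)).reshape(n * h * w, c1 * c0)

x = np.random.randn(32, 32, 14, 14, 16).astype(np.float16)
print(transpose02314_reference(x).shape)  # (6272, 512)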