diff --git a/example/resnet50_imagenet2012_THOR/config.py b/example/resnet50_imagenet2012_THOR/config.py index 6c664891f76..fc01287cc85 100644 --- a/example/resnet50_imagenet2012_THOR/config.py +++ b/example/resnet50_imagenet2012_THOR/config.py @@ -31,15 +31,7 @@ config = ed({ "save_checkpoint_steps": 5004, "keep_checkpoint_max": 20, "save_checkpoint_path": "./", - "lr_init": 0.01, - "lr_end": 0.00001, - "lr_max": 0.1, - "warmup_epochs": 0, - "lr_decay_mode": "cosine", "label_smooth": 1, "label_smooth_factor": 0.1, - "lr": 0.1, - "T_max": 90, - "eta_min": 0, - "frequency": 278 + "frequency": 834 }) diff --git a/example/resnet50_imagenet2012_THOR/model/thor.py b/example/resnet50_imagenet2012_THOR/model/thor.py index 44c0fd45dba..d414f238515 100644 --- a/example/resnet50_imagenet2012_THOR/model/thor.py +++ b/example/resnet50_imagenet2012_THOR/model/thor.py @@ -22,12 +22,6 @@ from mindspore.nn.optim.optimizer import Optimizer from mindspore.ops import functional as F, composite as C, operations as P from mindspore.parallel._utils import _get_device_num, _get_mirror_mean -from cus_ops.cus_matmul_cube_dense_right import CusMatMulCubeDenseRight -from cus_ops.cus_matmul_cube_fracz_left_cast import CusMatMulCubeFraczLeftCast -from cus_ops.cus_matmul_cube_dense_left import CusMatMulCubeDenseLeft -from cus_ops.cus_matmul_cube_fracz_right_mul import CusMatMulCubeFraczRightMul -from model.grad_reducer_thor import DistributedGradReducerThor - momentum_opt = C.MultitypeFuncGraph("momentum_opt") @@ -68,10 +62,10 @@ class THOR(Optimizer): self.matrix_G = ParameterTuple(matrix_G) self.A_inv_max = ParameterTuple(A_inv_max) self.G_inv_max = ParameterTuple(G_inv_max) - self.cube_matmul_left = CusMatMulCubeFraczLeftCast() - self.cube_matmul_left_fc = CusMatMulCubeDenseLeft() - self.cube_matmul_right_fc = CusMatMulCubeDenseRight() - self.cube_matmul_right_mul = CusMatMulCubeFraczRightMul() + self.cube_matmul_left = P.CusMatMulCubeFraczLeftCast() + self.cube_matmul_left_fc = P.CusMatMulCubeDenseLeft() + self.cube_matmul_right_fc = P.CusMatMulCubeDenseRight() + self.cube_matmul_right_mul = P.CusMatMulCubeFraczRightMul() self.transpose = P.Transpose() self.shape = P.Shape() self.reshape = P.Reshape() diff --git a/example/resnet50_imagenet2012_THOR/model/thor_layer.py b/example/resnet50_imagenet2012_THOR/model/thor_layer.py index 8097d729ea1..fea74605b68 100644 --- a/example/resnet50_imagenet2012_THOR/model/thor_layer.py +++ b/example/resnet50_imagenet2012_THOR/model/thor_layer.py @@ -23,19 +23,9 @@ from mindspore.common.tensor import Tensor from mindspore.nn.cell import Cell from mindspore.nn.layer.activation import get_activation from mindspore.ops import operations as P - -from cus_ops.cus_batch_matmul import CusBatchMatMul -from cus_ops.cus_cholesky_trsm import CusCholeskyTrsm -from cus_ops.cus_fused_abs_max1 import CusFusedAbsMax1 -from cus_ops.cus_img2col import CusImg2Col -from cus_ops.cus_matmul_cube import CusMatMulCube -from cus_ops.cus_matrix_combine import CusMatrixCombine -from cus_ops.cus_transpose02314 import CusTranspose02314 - import numpy as np C0 = 16 - def caculate_device_shape(matrix_dim, channel, is_A): ll = (0) if is_A: @@ -153,11 +143,11 @@ class Conv2d_Thor(_Conv): group=self.group ) - self.img2col = CusImg2Col(ksizes=ksizes, strides=strides) - self.cube_matmul = CusMatMulCube(transpose_a=True) - self.matrix_combine = CusMatrixCombine() - self.cholesky = CusCholeskyTrsm() - self.transpose02314 = CusTranspose02314() + self.img2col = P.CusImg2Col(ksizes=ksizes, strides=strides) + self.cube_matmul = P.CusMatMulCube(transpose_a=True) + self.matrix_combine = P.CusMatrixCombine() + self.cholesky = P.CusCholeskyTrsm() + self.transpose02314 = P.CusTranspose02314() self.matrix_A_dim = self.in_channels * self.kernel_size[0] * self.kernel_size[1] self.matrix_G_dim = self.out_channels self.matrix_A_device_shape, self.matrix_A_device_dim = caculate_device_shape(self.matrix_A_dim, @@ -190,7 +180,7 @@ class Conv2d_Thor(_Conv): self.mul = P.Mul() self.cast = P.Cast() self.damping = Tensor(damping) - self.vector_matmul = CusBatchMatMul() + self.vector_matmul = P.CusBatchMatMul() self.diag_block_dim = 128 self.channels_slice_flag = False if self.in_channels % C0 != 0: @@ -221,8 +211,8 @@ class Conv2d_Thor(_Conv): self.dampingA = Tensor(np.identity(dampingA_dim), mstype.float32) self.dampingG = Tensor(np.identity(dampingG_dim), mstype.float32) - self.fused_abs_max1 = CusFusedAbsMax1([self.matrix_A_dim, self.matrix_A_dim]) - self.fused_abs_max2 = CusFusedAbsMax1() + self.fused_abs_max1 = P.CusFusedAbsMax1([self.matrix_A_dim, self.matrix_A_dim]) + self.fused_abs_max2 = P.CusFusedAbsMax1() self.log = P.Log() self.exp = P.Exp() self.sqrt = P.Sqrt() @@ -375,9 +365,9 @@ class Dense_Thor(Cell): self.fake_G = Tensor(np.zeros([63, 63, 16, 16]).astype(np.float16)) self.matmul = P.MatMul(transpose_b=True) - self.cube_matmul = CusMatMulCube(transpose_a=True) - self.matrix_combine = CusMatrixCombine() - self.cholesky = CusCholeskyTrsm() + self.cube_matmul = P.CusMatMulCube(transpose_a=True) + self.matrix_combine = P.CusMatrixCombine() + self.cholesky = P.CusCholeskyTrsm() self.shape = P.Shape() self.reshape = P.Reshape() self.transpose = P.Transpose() @@ -386,7 +376,7 @@ class Dense_Thor(Cell): self.cast = P.Cast() self.damping = Tensor(damping) self.loss_scale = Tensor(1 / loss_scale, mstype.float16) - self.vector_matmul = CusBatchMatMul() + self.vector_matmul = P.CusBatchMatMul() self.pad = P.Pad(((0, 24), (0, 24))) self.pad1 = P.Pad(((0, 8), (0, 8))) self.slice = P.Slice() @@ -396,8 +386,8 @@ class Dense_Thor(Cell): self.axis = 0 self.A_inv_max = Parameter(initializer(0, [1], mstype.float32), name="A_inv_max", requires_grad=False) self.G_inv_max = Parameter(initializer(0, [1], mstype.float32), name="G_inv_max", requires_grad=False) - self.fused_abs_max1 = CusFusedAbsMax1([1000, 1000]) - self.fused_abs_max2 = CusFusedAbsMax1() + self.fused_abs_max1 = P.CusFusedAbsMax1([1000, 1000]) + self.fused_abs_max2 = P.CusFusedAbsMax1() self.log = P.Log() self.exp = P.Exp() self.dampingA = Tensor(np.identity(2048), mstype.float32) diff --git a/example/resnet50_imagenet2012_THOR/run_distribute_train.sh b/example/resnet50_imagenet2012_THOR/run_distribute_train.sh index ae05c45dfe1..e39034a9127 100644 --- a/example/resnet50_imagenet2012_THOR/run_distribute_train.sh +++ b/example/resnet50_imagenet2012_THOR/run_distribute_train.sh @@ -45,8 +45,7 @@ do mkdir ./train_parallel$i cp *.py ./train_parallel$i cp *.sh ./train_parallel$i - cp -r second_order ./train_parallel$i/second_order - cp -r test_ops ./train_parallel$i/test_ops + cp -r model ./train_parallel$i cd ./train_parallel$i || exit echo "start training for rank $RANK_ID, device $DEVICE_ID" diff --git a/mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py b/mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py new file mode 100644 index 00000000000..97982c53cf5 --- /dev/null +++ b/mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py @@ -0,0 +1,257 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""batch_matmul_impl""" + +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType +from te import tik +from topi.cce import util + +cus_batchmatmul_op_info = TBERegOp("CusBatchMatMul") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("batchmatmul.so") \ + .compute_cost(10) \ + .kernel_name("CusBatchMatMul") \ + .partial_flag(True) \ + .input(0, "x1", False, "required", "all") \ + .input(1, "x2", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + + +def _get_flattern_shape(shape): + """_get_flattern_shape""" + flattern_shape = 1 + for dim in shape: + flattern_shape *= dim + return (flattern_shape,) + + +def _inner_matmul_new(tik_instance, dtype, input1, input1_index, input2, input2_index, res, res_index): + """_inner_matmul_new""" + input_1_local_UB = tik_instance.Tensor(dtype, [128], name="input_1_local_UB", scope=tik.scope_ubuf) + t_1_0_local_UB = tik_instance.Tensor(dtype, [64 * 128], name="t_1_0_local_UB", scope=tik.scope_ubuf) + tik_instance.data_move(input_1_local_UB, input1[input1_index], 0, 1, 16, 0, 0) + with tik_instance.for_range(0, 2) as vec_i: + tik_instance.vadds(64, t_1_0_local_UB[vec_i * 64], input_1_local_UB[vec_i * 64], 0, 64, 1, 1, 16, 0) + with tik_instance.for_range(0, 2, thread_num=2) as thread_idx2: + input_2_local_UB = tik_instance.Tensor(dtype, [64 * 128], name="input_2_local_UB", + scope=tik.scope_ubuf) + t_1_local_UB = input_2_local_UB + bisec_last_axis_local_UB = input_2_local_UB + matmul_hybrid_f_t_local_UB = tik_instance.Tensor(dtype, [64], name="matmul_hybrid_f_t_local_UB", + scope=tik.scope_ubuf) + matmul_hybrid_f_t_local_UB_dst_tmp = tik_instance.Tensor(dtype, [64], + name="matmul_hybrid_f_t_local_UB_dst_tmp", + scope=tik.scope_ubuf) + tik_instance.vector_dup(64, matmul_hybrid_f_t_local_UB, 0, 1, 1, 8) + tik_instance.data_move(input_2_local_UB, input2[input2_index + thread_idx2 * 8192], 0, 1, 1024, 0, 0) + tik_instance.vmul(64, t_1_local_UB, t_1_0_local_UB, input_2_local_UB, 128, 1, 1, 1, 8, 8, 8) + tik_instance.vadd(64, bisec_last_axis_local_UB, t_1_local_UB, t_1_local_UB[64], 64, 1, 1, 1, + 16, 16, 16) + tik_instance.vector_dup(64, matmul_hybrid_f_t_local_UB_dst_tmp, 0, 1, 1, 8) + with tik_instance.for_range(0, 64) as cc6: + tik_instance.vcadd(64, matmul_hybrid_f_t_local_UB_dst_tmp[cc6], bisec_last_axis_local_UB[cc6 * 128], + 1, 1, 1, 8) + tik_instance.vadd(64, matmul_hybrid_f_t_local_UB, matmul_hybrid_f_t_local_UB_dst_tmp, + matmul_hybrid_f_t_local_UB, 1, 1, 1, 1, 8, 8, 8) + tik_instance.data_move(res[res_index + thread_idx2 * 64], + matmul_hybrid_f_t_local_UB, 0, 1, 8, 0, 0) + + +def _inner_matmul_new_1_64_32_64(tik_instance, dtype, input1, input1_index, input2, input2_index, res, res_index): + """_inner_matmul_new_1_64_32_64""" + input_1_local_UB = tik_instance.Tensor(dtype, [64], name="input_1_local_UB", scope=tik.scope_ubuf) + tik_instance.data_move(input_1_local_UB, input1[input1_index], 0, 1, 8, 0, 0) + with tik_instance.for_range(0, 2, thread_num=2) as thread_idx2: + input_2_local_UB = tik_instance.Tensor(dtype, [32 * 64], name="input_2_local_UB", + scope=tik.scope_ubuf) + t_1_local_UB = input_2_local_UB + matmul_hybrid_f_t_local_UB = tik_instance.Tensor(dtype, [32], name="matmul_hybrid_f_t_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_2_local_UB, input2[input2_index + thread_idx2 * 2048], 0, 1, 256, 0, 0) + tik_instance.vmul(64, t_1_local_UB, input_1_local_UB, input_2_local_UB, 32, 1, 1, 1, 8, 0, 8) + with tik_instance.for_range(0, 32) as cc6: + tik_instance.vcadd(64, matmul_hybrid_f_t_local_UB[cc6], t_1_local_UB[cc6 * 64], + 1, 1, 1, 8) + tik_instance.data_move(res[res_index + thread_idx2 * 32], + matmul_hybrid_f_t_local_UB, 0, 1, 4, 0, 0) + + +@op_info_register(cus_batchmatmul_op_info) +def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=True, kernel_name="batchmatmul"): + """CusBatchMatMul""" + if util.get_product_version() == util.VERSION_MINI: + tik_instance = tik.Tik(tik.Dprofile("v100", "mini")) + else: + tik_instance = tik.Tik(tik.Dprofile("v100", "cloud")) + x1_shape = input_x1.get("shape") + dtype = input_x1.get("dtype").lower() + x2_shape = input_x2.get("shape") + if dtype != input_x2.get("dtype").lower(): + raise RuntimeError("dtype of input_x1 and input_x2 must be same, but got %s vs %s" % ( + dtype, input_x2.get("dtype").lower())) + input_shape = (tuple(x1_shape), tuple(x2_shape), dtype, transpose_a, transpose_b) + support_shape = [((8, 128, 128), (8, 128, 128), "float32", False, True), + ((36, 128, 128), (36, 128, 128), "float32", False, True), + ((5, 128, 128), (5, 128, 128), "float32", False, True), + ((18, 128, 128), (18, 128, 128), "float32", False, True), + ((16, 128, 128), (16, 128, 128), "float32", False, True), + ((9, 128, 128), (9, 128, 128), "float32", False, True), + ((1, 64, 64), (1, 64, 64), "float32", False, True), + ((1, 128, 128), (1, 128, 128), "float32", False, True), + ((4, 128, 128), (4, 128, 128), "float32", False, True), + ((2, 128, 128), (2, 128, 128), "float32", False, True)] + if input_shape not in support_shape: + raise RuntimeError("input_shape %s is not supported" % str(input_shape)) + + # if not transpose_a and transpose_b: + batch, m, k = x1_shape + + input1_shape = _get_flattern_shape(x1_shape) + input1 = tik_instance.Tensor(dtype, input1_shape, name="input1", scope=tik.scope_gm) + input2_shape = _get_flattern_shape(x2_shape) + input2 = tik_instance.Tensor(dtype, input2_shape, name="input2", scope=tik.scope_gm) + + output_shape = x1_shape + res_shape = _get_flattern_shape(output_shape) + res = tik_instance.Tensor(dtype, res_shape, name="res", scope=tik.scope_gm) + + if input_shape == ((36, 128, 128), (36, 128, 128), "float32", False, True): + with tik_instance.for_range(0, 18, block_num=18) as block_idx: + with tik_instance.for_range(0, 2) as cc0: + with tik_instance.for_range(0, 128, thread_num=2) as cc1: + input1_index = block_idx * 32768 + cc0 * 16384 + cc1 * 128 + input2_index = block_idx * 32768 + cc0 * 16384 + res_index = block_idx * 32768 + cc0 * 16384 + cc1 * 128 + _inner_matmul_new(tik_instance, dtype, + input1, input1_index, + input2, input2_index, + res, res_index) + if input_shape == ((5, 128, 128), (5, 128, 128), "float32", False, True): + with tik_instance.for_range(0, 30, block_num=30) as block_idx: + with tik_instance.for_range(0, 11) as cc1_db: + with tik_instance.for_range(0, 2, thread_num=2) as thread_idx: + with tik_instance.if_scope(((((block_idx % 6) * 22) + (cc1_db * 2) + thread_idx) < 128)): + input_1_local_UB = tik_instance.Tensor(dtype, [128], name="input_1_local_UB", + scope=tik.scope_ubuf) + t_1_0_local_UB = tik_instance.Tensor(dtype, [64 * 128], name="t_1_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_1_local_UB, input1[ + (block_idx // 6) * 16384 + (block_idx % 6) * 2816 + cc1_db * 256 + thread_idx * 128], 0, 1, + 16, 0, 0) + with tik_instance.for_range(0, 2) as vec_i: + tik_instance.vadds(64, t_1_0_local_UB[vec_i * 64], input_1_local_UB[vec_i * 64], 0, + 64, 1, 1, 16, 0) + with tik_instance.for_range(0, 2, thread_num=2) as thread_idx2: + input_2_local_UB = tik_instance.Tensor(dtype, [64 * 128], name="input_2_local_UB", + scope=tik.scope_ubuf) + t_1_local_UB = input_2_local_UB + bisec_last_axis_local_UB = input_2_local_UB + matmul_hybrid_f_t_local_UB = tik_instance.Tensor(dtype, [64], + name="matmul_hybrid_f_t_local_UB", + scope=tik.scope_ubuf) + matmul_hybrid_f_t_local_UB_dst_tmp = tik_instance.Tensor(dtype, [64], + name="matmul_hybrid_f_t_local_UB_dst_tmp", + scope=tik.scope_ubuf) + tik_instance.vector_dup(64, matmul_hybrid_f_t_local_UB, 0, 1, 1, 8) + tik_instance.data_move(input_2_local_UB, + input2[(block_idx // 6) * 16384 + thread_idx2 * 8192], 0, 1, + 1024, 0, 0) + tik_instance.vmul(64, t_1_local_UB, t_1_0_local_UB, input_2_local_UB, 128, 1, 1, 1, 8, 8, 8) + tik_instance.vadd(64, bisec_last_axis_local_UB, t_1_local_UB, t_1_local_UB[64], 64, 1, 1, 1, + 16, 16, 16) + tik_instance.vector_dup(64, matmul_hybrid_f_t_local_UB_dst_tmp, 0, 1, 1, 8) + with tik_instance.for_range(0, 64) as cc6: + tik_instance.vcadd(64, matmul_hybrid_f_t_local_UB_dst_tmp[cc6], + bisec_last_axis_local_UB[cc6 * 128], + 1, 1, 1, 8) + tik_instance.vadd(64, matmul_hybrid_f_t_local_UB, matmul_hybrid_f_t_local_UB_dst_tmp, + matmul_hybrid_f_t_local_UB, 1, 1, 1, 1, 8, 8, 8) + tik_instance.data_move( + res[(block_idx // 6) * 16384 + (block_idx % 6) * 2816 + cc1_db * 256 + + thread_idx * 128 + thread_idx2 * 64], + matmul_hybrid_f_t_local_UB, 0, 1, 8, 0, 0) + + if input_shape == ((18, 128, 128), (18, 128, 128), "float32", False, True): + with tik_instance.for_range(0, 18, block_num=18) as block_idx: + with tik_instance.for_range(0, 128, thread_num=2) as cc0: + input1_index = block_idx * 16384 + cc0 * 128 + input2_index = block_idx * 16384 + res_index = block_idx * 16384 + cc0 * 128 + _inner_matmul_new(tik_instance, dtype, + input1, input1_index, + input2, input2_index, + res, res_index) + + if input_shape == ((9, 128, 128), (9, 128, 128), "float32", False, True): + with tik_instance.for_range(0, 27, block_num=27) as block_idx: + with tik_instance.for_range(0, 42, thread_num=2) as cc0: + input1_index = (block_idx // 3) * 16384 + (block_idx % 3) * 5504 + cc0 * 128 + input2_index = (block_idx // 3) * 16384 + res_index = (block_idx // 3) * 16384 + (block_idx % 3) * 5504 + cc0 * 128 + _inner_matmul_new(tik_instance, dtype, + input1, input1_index, + input2, input2_index, + res, res_index) + with tik_instance.if_scope((block_idx % 3) < 2): + input1_index = (block_idx // 3) * 16384 + (block_idx % 3) * 5504 + 42 * 128 + input2_index = (block_idx // 3) * 16384 + res_index = (block_idx // 3) * 16384 + (block_idx % 3) * 5504 + 42 * 128 + _inner_matmul_new(tik_instance, dtype, + input1, input1_index, + input2, input2_index, + res, res_index) + + if input_shape == ((1, 64, 64), (1, 64, 64), "float32", False, True): + with tik_instance.for_range(0, 32, block_num=32) as block_idx: + with tik_instance.for_range(0, 2, thread_num=2) as cc0: + input1_index = block_idx * 128 + cc0 * 64 + input2_index = 0 + res_index = block_idx * 128 + cc0 * 64 + _inner_matmul_new_1_64_32_64(tik_instance, dtype, + input1, input1_index, + input2, input2_index, + res, res_index) + + input_shape_list = [((1, 128, 128), (1, 128, 128), "float32", False, True), + ((2, 128, 128), (2, 128, 128), "float32", False, True), + ((4, 128, 128), (4, 128, 128), "float32", False, True), + ((8, 128, 128), (8, 128, 128), "float32", False, True), + ((16, 128, 128), (16, 128, 128), "float32", False, True) + ] + if input_shape in input_shape_list: + block_num = 32 + input1_unit_size = 128 + input2_unint_size = 128 * 128 + with tik_instance.for_range(0, block_num, block_num=block_num) as block_idx: + block_process_ele_num = (batch * m * k) // block_num + loop_time = (batch * m * k) // block_num // input1_unit_size + thread_num = 2 + with tik_instance.for_range(0, loop_time, thread_num=thread_num) as cc0: + input1_index = block_idx * block_process_ele_num + cc0 * input1_unit_size + if batch > 1: + input2_index = block_idx // (block_num // batch) * input2_unint_size + else: + input2_index = 0 + res_index = block_idx * block_process_ele_num + cc0 * input1_unit_size + _inner_matmul_new(tik_instance, dtype, + input1, input1_index, + input2, input2_index, + res, res_index) + + tik_instance.BuildCCE(kernel_name, inputs=[input1, input2], outputs=[res]) + return tik_instance diff --git a/mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py b/mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py new file mode 100644 index 00000000000..71dd1ccb2d5 --- /dev/null +++ b/mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py @@ -0,0 +1,111 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""CusCholeskyTrsm""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType +from te import tik +from topi.cce import util + +cus_cholesky_trsm_op_info = TBERegOp("CusCholeskyTrsm") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("choleskytrsm.so") \ + .compute_cost(10) \ + .kernel_name("CusCholeskyTrsm") \ + .partial_flag(True) \ + .input(0, "x1", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(cus_cholesky_trsm_op_info) +def CusCholeskyTrsm(input_x, output, kernel_name): + """CusCholeskyTrsm""" + input_x_shape = input_x.get("shape") + output_shape = output.get("shape") + split_dim = 128 + matrix_dim = input_x_shape[0] + split_dim = min(matrix_dim, split_dim) + vector_repeat_times = int(split_dim // 64) + blocks = int(matrix_dim // split_dim) + if blocks == 0: + blocks = 1 + if util.get_product_version() == util.VERSION_MINI: + tik_instance = tik.Tik(tik.Dprofile("v100", "mini")) + else: + tik_instance = tik.Tik(tik.Dprofile("v100", "cloud")) + + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (split_dim, split_dim), name="input_x_ub", scope=tik.scope_ubuf) + temp_ub = tik_instance.Tensor("float32", (split_dim, split_dim), name="temp_ub", scope=tik.scope_ubuf) + assist_1_ub = tik_instance.Tensor("float32", (split_dim,), name="assist_1_ub", scope=tik.scope_ubuf) + assist_2_ub = tik_instance.Tensor("float32", (split_dim,), name="assist_2_ub", scope=tik.scope_ubuf) + with tik_instance.for_range(0, split_dim) as i: + tik_instance.data_move(input_x_ub[i, 0], input_x[block_index * split_dim + i, block_index * split_dim], 0, + 1, vector_repeat_times * 8, 0, 0) + scalar1 = tik_instance.Scalar("float32", init_value=-0.5) + + with tik_instance.for_range(0, split_dim) as i: + scalar2 = tik_instance.Scalar("float32") + tik_instance.vln(64, assist_1_ub[0], input_x_ub[i, 0], vector_repeat_times, 1, 1, 8, 8) + tik_instance.vmuls(64, assist_2_ub[0], assist_1_ub[0], scalar1, vector_repeat_times, 1, 1, 8, 8) + tik_instance.vexp(64, assist_1_ub[0], assist_2_ub[0], vector_repeat_times, 1, 1, 8, 8) + scalar2.set_as(assist_1_ub[i]) + tik_instance.vmuls(64, input_x_ub[i, 0], input_x_ub[i, 0], scalar2, vector_repeat_times, 1, 1, 8, 8) + with tik_instance.for_range(i + 1, split_dim) as j: + scalar3 = tik_instance.Scalar("float32") + scalar3.set_as(input_x_ub[i, j]) + tik_instance.vmuls(64, temp_ub[j, 0], input_x_ub[i, 0], scalar3, vector_repeat_times, 1, 1, 8, 8) + tik_instance.vsub(64, input_x_ub[i + 1, 0], input_x_ub[i + 1, 0], temp_ub[i + 1, 0], + (split_dim - 1 - i) * vector_repeat_times, 1, 1, 1, 8, 8, 8) + + zero = tik_instance.Scalar("float32") + zero.set_as(0.0) + one = tik_instance.Scalar("float32") + one.set_as(1.0) + with tik_instance.for_range(0, split_dim) as i: + tik_instance.vector_dup(64, temp_ub[i, 0], zero, vector_repeat_times, 1, 8) + temp_ub.__setitem__(i * split_dim + i, one) + + chol_diag_element_final = tik_instance.Scalar("float32") + chol_diag_element_final.set_as(input_x_ub[split_dim * split_dim - 1]) + trsm_diag_element = tik_instance.Scalar("float32") + trsm_diag_element.set_as(1.0 / chol_diag_element_final) + temp_ub.__setitem__(split_dim * split_dim - 1, trsm_diag_element) + + with tik_instance.for_range(1, split_dim) as i: + index = split_dim - i - 1 + tik_instance.vector_dup(64, assist_1_ub, zero, vector_repeat_times, 1, 8) + with tik_instance.for_range(0, i) as j: + chol_diag_element_loop = tik_instance.Scalar("float32") + chol_diag_element_loop.set_as(input_x_ub[index, index + 1 + j]) + tik_instance.vmuls(64, assist_2_ub, temp_ub[j + index + 1, 0], chol_diag_element_loop, + vector_repeat_times, 1, 1, 8, 8) + tik_instance.vadd(64, assist_1_ub, assist_2_ub, assist_1_ub, vector_repeat_times, 1, 1, 1, 8, 8, 8) + temp_scalar = tik_instance.Scalar("float32") + temp_scalar.set_as(input_x_ub[index, index]) + chol_diag_element = tik_instance.Scalar("float32") + chol_diag_element.set_as(1.0 / temp_scalar) + tik_instance.vsub(64, temp_ub[index, 0], temp_ub[index, 0], assist_1_ub, vector_repeat_times, 1, 1, 1, 8, 8, + 8) + tik_instance.vmuls(64, temp_ub[index, 0], temp_ub[index, 0], chol_diag_element, vector_repeat_times, 1, 1, + 8, 8) + + tik_instance.data_move(res[block_index, 0, 0], temp_ub, 0, 1, 8 * vector_repeat_times * split_dim, 0, 0) + + tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x], outputs=[res]) + return tik_instance diff --git a/mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py b/mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py new file mode 100644 index 00000000000..f4b8d44063b --- /dev/null +++ b/mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py @@ -0,0 +1,1082 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""CusFusedAbsMax1""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType +from te import tik +from topi.cce import util + +cus_fused_abs_max1_op_info = TBERegOp("CusFusedAbsMax1") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("fusedabsmax1.so") \ + .compute_cost(10) \ + .kernel_name("CusFusedAbsMax1") \ + .partial_flag(True) \ + .attr("origin_shape", "required", "listInt", "all") \ + .input(0, "x1", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(cus_fused_abs_max1_op_info) +def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_max1"): + """CusFusedAbsMax1""" + input_x_shape = input_x.get("shape") + output_shape = output.get("shape") + + if util.get_product_version() == util.VERSION_MINI: + tik_instance = tik.Tik(tik.Dprofile("v100", "mini")) + else: + tik_instance = tik.Tik(tik.Dprofile("v100", "cloud")) + + if len(input_x_shape) > 2: + if (input_x_shape[0] == 1 and input_x_shape[1] == 128 and input_x_shape[2] == 128) or ( + input_x_shape[0] == 4 and input_x_shape[1] == 16) or (input_x_shape[0] == 16 and input_x_shape[1] == 4): + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + repeat_time = each_block_element // 64 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif (input_x_shape[0] == 2 and input_x_shape[1] == 128 and input_x_shape[2] == 128) or ( + input_x_shape[0] == 16 and input_x_shape[1] == 8): + if origin_shape[0] == 147 and ( + input_x_shape[0] == 2 and input_x_shape[1] == 128 and input_x_shape[2] == 128): + assert origin_shape[0] == 147 + assert origin_shape[1] == 147 + phase_1 = 16384 + phase_2 = 1216 + blocks = 32 + each_block_element = phase_1 // blocks + 64 + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[512 * block_index], 0, 1, 512 // 8, 0, 0) + line_id = block_index % 19 + tik_instance.data_move(input_x_ub[512], input_x[16384 + 128 * line_id], 0, 1, 8, 0, 0) + repeat_time = each_block_element // 64 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) + tik_instance.vmax(19, input_x_ub, input_x_ub, input_x_ub[512], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, + 1, 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + else: + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + repeat_time = each_block_element // 64 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, + 1, 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif (input_x_shape[0] == 4 and input_x_shape[1] == 128 and input_x_shape[2] == 128) or ( + input_x_shape[0] == 8 and input_x_shape[1] == 32) or (input_x_shape[0] == 32 and input_x_shape[1] == 8): + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + repeat_time = each_block_element // 64 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif (input_x_shape[0] == 8 and input_x_shape[1] == 128 and input_x_shape[2] == 128) or ( + input_x_shape[0] == 32 and input_x_shape[1] == 16) or ( + input_x_shape[0] == 16 and input_x_shape[1] == 32): + if (input_x_shape[0] == 8 and input_x_shape[1] == 128 and input_x_shape[2] == 128) and origin_shape[ + 0] == 1000: + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + blocks = 32 + each_block_element = 7 * 128 * 128 // 32 + 4 * 128 + phase_1 = 7 * 128 * 128 // 32 + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[phase_1 * block_index], 0, 1, phase_1 // 8, 0, 0) + tik_instance.data_move(input_x_ub[phase_1], input_x[114688 + block_index * 384], 0, 1, 384 // 8, 0, + 0) + move_idx = block_index % 8 + tik_instance.data_move(input_x_ub[phase_1 + 384], input_x[114688 + 96 * 128 + move_idx * 128], 0, 1, + 128 // 8, 0, 0) + repeat_time = each_block_element // 64 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) + vmask = 1000 - 7 * 128 - 64 + with tik_instance.for_range(0, 4) as loop_idx: + tik_instance.vmax(vmask, input_x_ub[3584 + 128 * loop_idx], input_x_ub[3584 + 128 * loop_idx], + input_x_ub[3584 + 128 * loop_idx + 64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub[512], input_x_ub[2048], 24, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + + with tik_instance.for_range(0, 4) as loop_idx: + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[3584 + 128 * loop_idx], 1, 1, 1, 1, 8, + 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, + 1, 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + + elif (input_x_shape[0] == 8 and input_x_shape[1] == 128 and input_x_shape[2] == 128) and origin_shape[ + 0] == 1001: + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + blocks = 32 + each_block_element = 7 * 128 * 128 // 32 + 4 * 128 + phase_1 = 7 * 128 * 128 // 32 + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[phase_1 * block_index], 0, 1, phase_1 // 8, 0, 0) + tik_instance.data_move(input_x_ub[phase_1], input_x[114688 + block_index * 384], 0, 1, 384 // 8, 0, + 0) + tik_instance.data_move(input_x_ub[phase_1], input_x[114688 + block_index * 384], 0, 1, 384 // 8, 0, + 0) + move_idx = block_index % 9 + tik_instance.data_move(input_x_ub[phase_1 + 384], input_x[114688 + 96 * 128 + move_idx * 128], 0, 1, + 128 // 8, 0, 0) + repeat_time = each_block_element // 64 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) + vmask = 1001 - 7 * 128 - 64 + with tik_instance.for_range(0, 4) as loop_idx: + tik_instance.vmax(vmask, input_x_ub[3584 + 128 * loop_idx], input_x_ub[3584 + 128 * loop_idx], + input_x_ub[3584 + 128 * loop_idx + 64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub[512], input_x_ub[2048], 24, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 4) as loop_idx: + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[3584 + 128 * loop_idx], 1, 1, 1, 1, 8, + 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, + 1, 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + else: + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + repeat_time = each_block_element // 64 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, + 1, 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif (input_x_shape[0] == 16 and input_x_shape[1] == 128 and input_x_shape[2] == 128) or ( + input_x_shape[0] == 16 and input_x_shape[1] == 64) or ( + input_x_shape[0] == 64 and input_x_shape[1] == 16): + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + repeat_time = each_block_element // 64 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 64, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif input_x_shape[0] == 5 and input_x_shape[1] == 128 and input_x_shape[2] == 128 and origin_shape[0] == 576: + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 69632 + blocks = 32 + each_block_element = total_elements // blocks + phase_1 = 2048 + phase_2 = 128 + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[phase_1 * block_index], 0, 1, phase_1 // 8, 0, 0) + tik_instance.data_move(input_x_ub[phase_1], input_x[65536 + phase_2 * block_index * 2], 0, 1, 8, 0, 0) + tik_instance.data_move(input_x_ub[phase_1 + 64], input_x[65536 + 128 + phase_2 * block_index * 2], 0, 1, + 8, 0, 0) + repeat_time = each_block_element // 64 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub[2048], input_x_ub[2048], input_x_ub[2048 + 64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif (input_x_shape[0] == 9 and input_x_shape[1] == 128 and input_x_shape[2] == 128) or ( + input_x_shape[0] == 72 and input_x_shape[1] == 8): + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + repeat_time = each_block_element // 64 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[4096], input_x_ub[4096], input_x_ub[4096 + 256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[4096], input_x_ub[4096], input_x_ub[4096 + 128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[4096], input_x_ub[4096], input_x_ub[4096 + 64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif input_x_shape[0] == 18 and input_x_shape[1] == 128 and input_x_shape[2] == 128: + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + repeat_time = each_block_element // 64 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 64, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[8192], input_x_ub[8192], input_x_ub[8192 + 512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[8192], input_x_ub[8192], input_x_ub[8192 + 256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[8192], input_x_ub[8192], input_x_ub[8192 + 128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[8192], input_x_ub[8192], input_x_ub[8192 + 64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[8192], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif (input_x_shape[0] == 36 and input_x_shape[1] == 128 and input_x_shape[2] == 128) or ( + input_x_shape[0] == 144 and input_x_shape[1] == 16): + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + repeat_time_1 = 255 + repeat_time_2 = each_block_element // 64 - 255 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time_1, 1, 1, 8, 8) + tik_instance.vabs(64, input_x_ub[repeat_time_1 * 64], input_x_ub[repeat_time_1 * 64], repeat_time_2, 1, + 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[8192], 128, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 64, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16384], input_x_ub[16384], input_x_ub[16384 + 1024], 16, 1, 1, 1, 8, 8, + 8) + tik_instance.vmax(64, input_x_ub[16384], input_x_ub[16384], input_x_ub[16384 + 512], 8, 1, 1, 1, 8, 8, + 8) + tik_instance.vmax(64, input_x_ub[16384], input_x_ub[16384], input_x_ub[16384 + 256], 4, 1, 1, 1, 8, 8, + 8) + tik_instance.vmax(64, input_x_ub[16384], input_x_ub[16384], input_x_ub[16384 + 128], 2, 1, 1, 1, 8, 8, + 8) + tik_instance.vmax(64, input_x_ub[16384], input_x_ub[16384], input_x_ub[16384 + 64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[16384], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif input_x_shape[0] == 128 and input_x_shape[1] == 63: + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + repeat_time_1 = 255 + repeat_time_2 = each_block_element // 64 - 255 * 3 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time_1, 1, 1, 8, 8) + tik_instance.vabs(64, input_x_ub[repeat_time_1 * 64], input_x_ub[repeat_time_1 * 64], repeat_time_1, 1, + 1, 8, 8) + tik_instance.vabs(64, input_x_ub[repeat_time_1 * 2 * 64], input_x_ub[repeat_time_1 * 2 * 64], + repeat_time_1, 1, 1, 8, 8) + tik_instance.vabs(64, input_x_ub[repeat_time_1 * 3 * 64], input_x_ub[repeat_time_1 * 3 * 64], + repeat_time_2, 1, 1, 8, 8) + loop_size = each_block_element // 16384 + with tik_instance.for_range(0, loop_size) as loop_idx: + tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], + input_x_ub[16384 * loop_idx + 8192], 128, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], + input_x_ub[16384 * loop_idx + 4096], 64, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], + input_x_ub[16384 * loop_idx + 2048], 32, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], + input_x_ub[16384 * loop_idx + 1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], + input_x_ub[16384 * loop_idx + 512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], + input_x_ub[16384 * loop_idx + 256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], + input_x_ub[16384 * loop_idx + 128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], + input_x_ub[16384 * loop_idx + 64], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, loop_size - 1) as loop_idx: + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[16384 * (loop_idx + 1)], 1, 1, 1, 1, 8, 8, + 8) + tail_element = each_block_element - 16384 * loop_size + repeats = tail_element // 64 + with tik_instance.for_range(0, repeats) as i: + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[16384 * loop_size + i * 64], 1, 1, 1, 1, 8, + 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, input_x_ub[64 + cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, input_x_ub[64], input_x_ub[64], input_x_ub[2048 + 64], 32, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[64], input_x_ub[64], input_x_ub[1024 + 64], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[64], input_x_ub[64], input_x_ub[512 + 64], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[64], input_x_ub[64], input_x_ub[256 + 64], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[64], input_x_ub[64], input_x_ub[128 + 64], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[64], input_x_ub[64], input_x_ub[64 + 64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.data_move(res[block_index, 0], input_x_ub[64], 0, 1, 8, 0, 0) + elif (input_x_shape[0] == 32 and input_x_shape[1] == 128) or ( + input_x_shape[0] == 128 and input_x_shape[1] == 32): + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + repeat_time_1 = 255 + repeat_time_2 = each_block_element // 64 - 255 * 2 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time_1, 1, 1, 8, 8) + tik_instance.vabs(64, input_x_ub[repeat_time_1 * 64], input_x_ub[repeat_time_1 * 64], repeat_time_1, 1, + 1, 8, 8) + tik_instance.vabs(64, input_x_ub[repeat_time_1 * 2 * 64], input_x_ub[repeat_time_1 * 2 * 64], + repeat_time_2, 1, 1, 8, 8) + loop_size = each_block_element // 16384 + with tik_instance.for_range(0, loop_size) as loop_idx: + tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], + input_x_ub[16384 * loop_idx + 8192], 128, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], + input_x_ub[16384 * loop_idx + 4096], 64, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], + input_x_ub[16384 * loop_idx + 2048], 32, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], + input_x_ub[16384 * loop_idx + 1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], + input_x_ub[16384 * loop_idx + 512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], + input_x_ub[16384 * loop_idx + 256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], + input_x_ub[16384 * loop_idx + 128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], + input_x_ub[16384 * loop_idx + 64], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, loop_size - 1) as loop_idx: + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[16384 * (loop_idx + 1)], 1, 1, 1, 1, 8, 8, + 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif input_x_shape[0] == 288 and input_x_shape[1] == 32: + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + assist_ub = tik_instance.Tensor("float32", (64,), name="assist_ub", scope=tik.scope_ubuf) + zero = tik_instance.Scalar("float32") + zero.set_as(0) + tik_instance.vector_dup(64, assist_ub, zero, 1, 1, 8) + input_x_ub = tik_instance.Tensor("float32", (32768,), name="input_x_ub", scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + repeat_time_1 = 255 + repeat_time_2 = 32768 // 64 - 255 * 2 + + tik_instance.data_move(input_x_ub[0], input_x[each_block_element * block_index + 0], 0, 1, 4096, 0, 0) + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time_1, 1, 1, 8, 8) + tik_instance.vabs(64, input_x_ub[repeat_time_1 * 64], input_x_ub[repeat_time_1 * 64], repeat_time_1, 1, + 1, 8, 8) + tik_instance.vabs(64, input_x_ub[repeat_time_1 * 2 * 64], input_x_ub[repeat_time_1 * 2 * 64], + repeat_time_2, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[16384], 255, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16320], input_x_ub[16320], input_x_ub[32704], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[8192], 128, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 64, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, assist_ub, assist_ub, input_x_ub, 1, 1, 1, 1, 8, 8, 8) + + tik_instance.data_move(input_x_ub[0], input_x[each_block_element * block_index + 32768], 0, 1, 4096, 0, + 0) + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time_1, 1, 1, 8, 8) + tik_instance.vabs(64, input_x_ub[repeat_time_1 * 64], input_x_ub[repeat_time_1 * 64], repeat_time_1, 1, + 1, 8, 8) + tik_instance.vabs(64, input_x_ub[repeat_time_1 * 2 * 64], input_x_ub[repeat_time_1 * 2 * 64], + repeat_time_2, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[16384], 255, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16320], input_x_ub[16320], input_x_ub[32704], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[8192], 128, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 64, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, assist_ub, assist_ub, input_x_ub, 1, 1, 1, 1, 8, 8, 8) + tik_instance.data_move(input_x_ub[0], input_x[each_block_element * block_index + 65536], 0, 1, 1024, 0, + 0) + tik_instance.vabs(64, input_x_ub, input_x_ub, 128, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 64, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, assist_ub, assist_ub, input_x_ub, 1, 1, 1, 1, 8, 8, 8) + + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(assist_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif input_x_shape[0] == 64 and input_x_shape[1] == 128: + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + assist_ub = tik_instance.Tensor("float32", (64,), name="assist_ub", scope=tik.scope_ubuf) + zero = tik_instance.Scalar("float32") + zero.set_as(0) + tik_instance.vector_dup(64, assist_ub, zero, 1, 1, 8) + input_x_ub = tik_instance.Tensor("float32", (32768,), name="input_x_ub", scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + repeat_time_1 = 255 + repeat_time_2 = 32768 // 64 - 255 * 2 + + tik_instance.data_move(input_x_ub[0], input_x[each_block_element * block_index + 0], 0, 1, 4096, 0, 0) + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time_1, 1, 1, 8, 8) + tik_instance.vabs(64, input_x_ub[repeat_time_1 * 64], input_x_ub[repeat_time_1 * 64], repeat_time_1, 1, + 1, 8, 8) + tik_instance.vabs(64, input_x_ub[repeat_time_1 * 2 * 64], input_x_ub[repeat_time_1 * 2 * 64], + repeat_time_2, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[16384], 255, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16320], input_x_ub[16320], input_x_ub[32704], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[8192], 128, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 64, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, assist_ub, assist_ub, input_x_ub, 1, 1, 1, 1, 8, 8, 8) + + tik_instance.data_move(input_x_ub[0], input_x[each_block_element * block_index + 32768], 0, 1, 4096, 0, + 0) + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time_1, 1, 1, 8, 8) + tik_instance.vabs(64, input_x_ub[repeat_time_1 * 64], input_x_ub[repeat_time_1 * 64], repeat_time_1, 1, + 1, 8, 8) + tik_instance.vabs(64, input_x_ub[repeat_time_1 * 2 * 64], input_x_ub[repeat_time_1 * 2 * 64], + repeat_time_2, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[16384], 255, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16320], input_x_ub[16320], input_x_ub[32704], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[8192], 128, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 64, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, assist_ub, assist_ub, input_x_ub, 1, 1, 1, 1, 8, 8, 8) + + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(assist_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif (input_x_shape[0] == 64 and input_x_shape[1] == 32) or (input_x_shape[0] == 32 and input_x_shape[1] == 64): + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + repeat_time_1 = 255 + repeat_time_2 = each_block_element // 64 - 255 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time_1, 1, 1, 8, 8) + tik_instance.vabs(64, input_x_ub[repeat_time_1 * 64], input_x_ub[repeat_time_1 * 64], repeat_time_2, 1, + 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[8192], 128, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 64, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif input_x_shape[0] == 36 and input_x_shape[1] == 4: + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + + repeat_time = each_block_element // 64 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub[1024], input_x_ub[1024], input_x_ub[1024 + 64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif input_x_shape[0] == 4 and input_x_shape[1] == 4: + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + + repeat_time = each_block_element // 64 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif input_x_shape[0] == 49 and input_x_shape[1] == 4: + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + + repeat_time = each_block_element // 64 + tik_instance.vabs(64, input_x_ub, input_x_ub, 24, 1, 1, 8, 8) + tik_instance.vabs(32, input_x_ub[1536], input_x_ub[1536], 1, 1, 1, 8, 8) + tik_instance.vmax(32, input_x_ub[1504], input_x_ub[1504], input_x_ub[1536], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[1024], input_x_ub[1024], input_x_ub[1024 + 256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[1024], input_x_ub[1024], input_x_ub[1024 + 128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[1024], input_x_ub[1024], input_x_ub[1024 + 64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif input_x_shape[0] == 1 and input_x_shape[1] == 64 and input_x_shape[2] == 64: + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + repeat_time = each_block_element // 64 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + + else: + raise RuntimeError("UnSupportedShape") + elif len(input_x_shape) == 2 and (input_x_shape[0] == 32 and input_x_shape[1] == 64): + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + input_x_ub = tik_instance.Tensor("float32", (32 * 64,), name="input_x_ub", scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x, 0, 1, 256, 0, 0) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.data_move(res[0], input_x_ub, 0, 1, 1, 0, 0) + else: + raise RuntimeError("UnSupportedShape") + + tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x], outputs=[res]) + return tik_instance diff --git a/mindspore/ops/_op_impl/_custom_op/img2col_impl.py b/mindspore/ops/_op_impl/_custom_op/img2col_impl.py new file mode 100644 index 00000000000..433e3355650 --- /dev/null +++ b/mindspore/ops/_op_impl/_custom_op/img2col_impl.py @@ -0,0 +1,1151 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""CusImg2ColNC1HWC0""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType +from te import tik +from topi.cce import util + +cus_img2col_info = TBERegOp("CusImg2Col") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("img2col.so") \ + .compute_cost(10) \ + .kernel_name("CusImg2Col") \ + .partial_flag(True) \ + .attr("ksizes", "required", "listInt", "all") \ + .attr("strides", "required", "listInt", "all") \ + .attr("dilates", "required", "listInt", "all") \ + .attr("mode", "required", "str", "all") \ + .input(0, "x1", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.F16_FracNZ) \ + .get_op_info() + + +@op_info_register(cus_img2col_info) +def CusImg2Col(input_x, output, ksizes, strides, dilates, mode, kernel_name="img2col"): + """CusImg2Col""" + input_x_shape = input_x.get("shape") + input_x_dtype = input_x.get("dtype") + N, C1, H, W, C0 = input_x_shape + C = C1 * C0 + padding = 'SAME' + _, filter_h, filter_w, _ = ksizes + _, stride_h, stride_w, _ = strides + _, dilation_filter_h, dilation_filter_w, _ = dilates + + input_shape = (tuple(input_x_shape), input_x_dtype, (filter_h, filter_w), (stride_h, stride_w)) + supported_shape = [((32, 32, 14, 14, 16), 'float16', (3, 3), (2, 2)), + ((32, 1, 224, 224, 16), 'float16', (7, 7), (2, 2)), + ((32, 4, 56, 56, 16), 'float16', (3, 3), (1, 1)), + ((32, 8, 56, 56, 16), 'float16', (3, 3), (2, 2)), + ((32, 8, 28, 28, 16), 'float16', (3, 3), (1, 1)), + ((32, 16, 28, 28, 16), 'float16', (3, 3), (2, 2)), + ((32, 16, 14, 14, 16), 'float16', (3, 3), (1, 1)), + ((32, 32, 7, 7, 16), 'float16', (3, 3), (1, 1)), + ((32, 64, 14, 14, 16), 'float16', (1, 1), (1, 1)), + ((32, 32, 7, 7, 16), 'float16', (1, 1), (1, 1)), + ((32, 4, 56, 56, 16), 'float16', (1, 1), (1, 1)), + ((32, 64, 14, 14, 16), 'float16', (1, 1), (2, 2)), + ((32, 128, 7, 7, 16), 'float16', (1, 1), (1, 1)), + ((32, 32, 28, 28, 16), 'float16', (1, 1), (2, 2)), + ((32, 16, 56, 56, 16), 'float16', (1, 1), (2, 2)), + ((32, 8, 28, 28, 16), 'float16', (1, 1), (1, 1)), + ((32, 32, 28, 28, 16), 'float16', (1, 1), (1, 1)), + ((32, 16, 14, 14, 16), 'float16', (1, 1), (1, 1)), + ((32, 16, 56, 56, 16), 'float16', (1, 1), (1, 1)),] + + if input_shape not in supported_shape: + raise RuntimeError("input_shape %s is not supported" % str(input_shape)) + + output_tmp = [N * int(H // stride_h) * int(W // stride_w), filter_h * filter_w * C] + output_shape = [output_tmp[1] // 16, output_tmp[0] // 16, 16, 16] + if util.get_product_version() == util.VERSION_MINI: + tik_instance = tik.Tik(tik.Dprofile("v100", "mini")) + else: + tik_instance = tik.Tik(tik.Dprofile("v100", "cloud")) + + input_x = tik_instance.Tensor("float16", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float16", output_shape, name="res", scope=tik.scope_gm) + + if input_shape == ((32, 1, 224, 224, 16), 'float16', (7, 7), (2, 2)): + pad = [3, 3, 3, 3] + l1_h = 56 + l1_w = 224 + c1_index = 0 + jump_stride = 1 + repeat_mode = 1 + + with tik_instance.for_range(0, 32, block_num=32) as block_index: + input_1_1_local_L1 = tik_instance.Tensor("float16", (200704,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (53760,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + + tik_instance.data_move(input_1_1_local_L1, input_x[block_index, 0, 0, 0, 0], 0, 1, 12544, 0, 0) + with tik_instance.for_range(0, 7) as eeb: + with tik_instance.for_range(0, 7) as cc0: + temp = eeb % 2 + rep = ((55 - temp - (-3 + eeb)) // 2 + 1) * 7 + fetch_filter_w = cc0 + fetch_filter_h = eeb + left_top_w = -3 + left_top_h = -3 + + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB, input_1_1_local_L1, + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + rep) + + with tik_instance.for_range(0, rep) as cc1: + tik_instance.data_move(res[cc0 + eeb * 7, cc1 + 784 * block_index, 0, 0], + input_1_1_fractal_L1_local_UB[cc1 * 256], 0, 1, 16, 0, 0) + + with tik_instance.for_range(1, 3) as eeb0: + tik_instance.data_move(input_1_1_local_L1, input_x[block_index, 0, 56 * eeb0, 0, 0], 0, 1, 12544, 0, 0) + with tik_instance.for_range(0, 7) as eeb: + with tik_instance.for_range(0, 7) as cc0: + temp = eeb % 2 + rep_prefix = ((55 - temp - (-3 + eeb)) // 2 + 1) * 7 + rep = 196 + fetch_filter_w = cc0 + fetch_filter_h = eeb + left_top_w = -3 + + left_top_h = 1 + ((55 - temp - (-3 + eeb)) // 2 - 29) * 2 + + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB, input_1_1_local_L1, + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + rep) + with tik_instance.for_range(0, rep) as cc1: + tik_instance.data_move( + res[cc0 + eeb * 7, cc1 + rep_prefix + (eeb0 - 1) * rep + 784 * block_index, 0, 0], + input_1_1_fractal_L1_local_UB[cc1 * 256], 0, 1, 16, 0, 0) + + tik_instance.data_move(input_1_1_local_L1, input_x[block_index, 0, 56 * 3, 0, 0], 0, 1, 12544, 0, 0) + + with tik_instance.for_range(0, 7) as eeb: + with tik_instance.for_range(0, 7) as cc0: + temp = eeb % 2 + rep_prefix = ((55 - temp - (-3 + eeb)) // 2 + 1) * 7 + 196 * 2 + rep = 784 - rep_prefix + fetch_filter_w = cc0 + fetch_filter_h = eeb + left_top_w = -3 + left_top_h = 1 + ((55 - temp - (-3 + eeb)) // 2 - 29) * 2 + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB, input_1_1_local_L1, + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + rep) + + with tik_instance.for_range(0, rep) as cc1: + tik_instance.data_move(res[cc0 + eeb * 7, cc1 + rep_prefix + 784 * block_index, 0, 0], + input_1_1_fractal_L1_local_UB[cc1 * 256], 0, 1, 16, 0, 0) + + if input_shape == ((32, 4, 56, 56, 16), 'float16', (3, 3), (1, 1)): + pad = [1, 1, 1, 1] + l1_h = 56 + l1_w = 56 + jump_stride = 1 + repeat_mode = 1 + with tik_instance.for_range(0, 32, block_num=32) as block_index: + input_1_1_local_L1 = tik_instance.Tensor("float16", (200704,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (50176,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + tik_instance.data_move(input_1_1_local_L1, input_x[block_index, 0, 0, 0, 0], 0, 1, 12544, 0, 0) + with tik_instance.for_range(0, 9) as eeb0: + rep = 196 + fetch_filter_w = eeb0 % 3 + fetch_filter_h = eeb0 // 3 + left_top_w = -1 + left_top_h = -1 + with tik_instance.for_range(0, 4) as eeb1: + c1_index = eeb1 + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB, input_1_1_local_L1, + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + rep) + with tik_instance.for_range(0, rep) as i: + tik_instance.data_move(res[eeb1 * 9 + eeb0, i + 196 * block_index, 0, 0], + input_1_1_fractal_L1_local_UB[i * 256], 0, 1, 16, 0, 0) + + if input_shape == ((32, 8, 56, 56, 16), 'float16', (3, 3), (2, 2)): + pad = [1, 1, 1, 1] + l1_h = 56 + l1_w = 56 + jump_stride = 1 + repeat_mode = 1 + with tik_instance.for_range(0, 32, block_num=32) as block_index: + input_1_1_local_L1 = tik_instance.Tensor("float16", (401408,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (112896,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + tik_instance.data_move(input_1_1_local_L1, input_x[block_index, 0, 0, 0, 0], 0, 1, 25088, 0, 0) + with tik_instance.for_range(0, 8) as eeb0: + with tik_instance.for_range(0, 9) as eeb1: + rep = 49 + fetch_filter_w = eeb1 % 3 + fetch_filter_h = eeb1 // 3 + left_top_w = -1 + left_top_h = -1 + c1_index = eeb0 + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB[49 * 256 * eeb1], input_1_1_local_L1, + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + rep) + with tik_instance.for_range(0, 9) as eeb1: + with tik_instance.for_range(0, 49) as i: + tik_instance.data_move(res[eeb1 + eeb0 * 9, 49 * block_index + i, 0, 0], + input_1_1_fractal_L1_local_UB[i * 256 + eeb1 * 49 * 256], 0, 1, 16, 0, 0) + + if input_shape == ((32, 8, 28, 28, 16), 'float16', (3, 3), (1, 1)): + pad = [1, 1, 1, 1] + l1_h = 28 + l1_w = 28 + jump_stride = 1 + repeat_mode = 1 + with tik_instance.for_range(0, 32, block_num=32) as block_index: + input_1_1_local_L1 = tik_instance.Tensor("float16", (100352,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (112896,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + tik_instance.data_move(input_1_1_local_L1, input_x[block_index, 0, 0, 0, 0], 0, 1, 6272, 0, 0) + with tik_instance.for_range(0, 8) as eeb0: + with tik_instance.for_range(0, 9) as eeb1: + rep = 49 + fetch_filter_w = eeb1 % 3 + fetch_filter_h = eeb1 // 3 + left_top_w = -1 + left_top_h = -1 + c1_index = eeb0 + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB[49 * 256 * eeb1], input_1_1_local_L1, + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + rep) + with tik_instance.for_range(0, 9) as eeb1: + with tik_instance.for_range(0, 49) as i: + tik_instance.data_move(res[eeb1 + eeb0 * 9, 49 * block_index + i, 0, 0], + input_1_1_fractal_L1_local_UB[i * 256 + eeb1 * 49 * 256], 0, 1, 16, 0, 0) + + if input_shape == ((32, 16, 28, 28, 16), 'float16', (3, 3), (2, 2)): + pad = [1, 1, 1, 1] + l1_h = 28 + l1_w = 28 + jump_stride = 1 + repeat_mode = 1 + with tik_instance.for_range(0, 32, block_num=32) as block_index: + eeb0 = block_index % 2 + eeb1 = block_index // 2 + input_1_1_local_L1 = tik_instance.Tensor("float16", (200704,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (53248,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + input_1_2_fractal_L1_local_UB = tik_instance.Tensor("float16", (50176,), scope=tik.scope_ubuf, + name="input_1_2_fractal_L1_local_UB") + with tik_instance.for_range(0, 16) as i: + tik_instance.data_move(input_1_1_local_L1[i * 12544], input_x[i + 16 * eeb0, eeb1, 0, 0, 0], 0, 1, 784, + 0, 0) + + with tik_instance.for_range(0, 9) as eeb3: + rep = 13 + fetch_filter_w = eeb3 % 3 + fetch_filter_h = eeb3 // 3 + left_top_w = -1 + left_top_h = -1 + c1_index = 0 + with tik_instance.for_range(0, 16) as i: + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB[3328 * i], input_1_1_local_L1[12544 * i], + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + rep) + with tik_instance.for_range(0, 16) as i: + tik_instance.data_move(input_1_2_fractal_L1_local_UB[i * 196 * 16], + input_1_1_fractal_L1_local_UB[i * 3328], 0, 1, 196, 0, 0) + + with tik_instance.for_range(196 * eeb0, 196 * (eeb0 + 1)) as i: + tik_instance.data_move(res[eeb1 * 9 + eeb3, i, 0, 0], + input_1_2_fractal_L1_local_UB[256 * (i - 196 * eeb0)], 0, 1, 16, 0, 0) + + if input_shape == ((32, 16, 14, 14, 16), 'float16', (3, 3), (1, 1)): + pad = [1, 1, 1, 1] + l1_h = 14 + l1_w = 14 + jump_stride = 1 + repeat_mode = 1 + with tik_instance.for_range(0, 32, block_num=32) as block_index: + eeb0 = block_index % 2 + eeb1 = block_index // 2 + input_1_1_local_L1 = tik_instance.Tensor("float16", (50176,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (53248,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + input_1_2_fractal_L1_local_UB = tik_instance.Tensor("float16", (50176,), scope=tik.scope_ubuf, + name="input_1_2_fractal_L1_local_UB") + with tik_instance.for_range(0, 16) as i: + tik_instance.data_move(input_1_1_local_L1[i * 3136], input_x[i + 16 * eeb0, eeb1, 0, 0, 0], 0, 1, 196, + 0, 0) + + with tik_instance.for_range(0, 9) as eeb3: + rep = 13 + fetch_filter_w = eeb3 % 3 + fetch_filter_h = eeb3 // 3 + left_top_w = -1 + left_top_h = -1 + c1_index = 0 + with tik_instance.for_range(0, 16) as i: + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB[3328 * i], input_1_1_local_L1[3136 * i], + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + rep) + with tik_instance.for_range(0, 16) as i: + tik_instance.data_move(input_1_2_fractal_L1_local_UB[i * 196 * 16], + input_1_1_fractal_L1_local_UB[i * 3328], 0, 1, 196, 0, 0) + + with tik_instance.for_range(196 * eeb0, 196 * (eeb0 + 1)) as i: + tik_instance.data_move(res[eeb1 * 9 + eeb3, i, 0, 0], + input_1_2_fractal_L1_local_UB[256 * (i - 196 * eeb0)], 0, 1, 16, 0, 0) + + if input_shape == ((32, 32, 14, 14, 16), 'float16', (3, 3), (2, 2)): + pad = [1, 1, 1, 1] + l1_h = 14 + l1_w = 14 + jump_stride = 1 + repeat_mode = 1 + with tik_instance.for_range(0, 32, block_num=32) as block_index: + input_1_1_local_L1 = tik_instance.Tensor("float16", (100352,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (32768,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + input_1_2_fractal_L1_local_UB = tik_instance.Tensor("float16", (25088,), scope=tik.scope_ubuf, + name="input_1_2_fractal_L1_local_UB") + with tik_instance.for_range(0, 32) as i: + tik_instance.data_move(input_1_1_local_L1[i * 3136], input_x[i, block_index, 0, 0, 0], 0, 1, 196, 0, 0) + with tik_instance.for_range(0, 9) as eeb: + rep = 4 + fetch_filter_w = eeb % 3 + fetch_filter_h = eeb // 3 + left_top_w = -1 + left_top_h = -1 + c1_index = 0 + with tik_instance.for_range(0, 32) as i: + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB[1024 * i], input_1_1_local_L1[3136 * i], + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + rep) + + with tik_instance.for_range(0, 32) as i: + tik_instance.data_move(input_1_2_fractal_L1_local_UB[i * 49 * 16], + input_1_1_fractal_L1_local_UB[i * 1024], 0, 1, 49, 0, 0) + + with tik_instance.for_range(0, 98) as i: + tik_instance.data_move(res[eeb + block_index * 9, i, 0, 0], input_1_2_fractal_L1_local_UB[256 * i], + 0, 1, 16, 0, 0) + + if input_shape == ((32, 64, 14, 14, 16), 'float16', (1, 1), (2, 2)): + pad = [0, 0, 0, 0] + l1_h = 14 + l1_w = 14 + jump_stride = 1 + repeat_mode = 1 + with tik_instance.for_range(0, 32, block_num=32) as block_index: + input_1_1_local_L1 = tik_instance.Tensor("float16", (100352,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (32768,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + input_1_2_fractal_L1_local_UB = tik_instance.Tensor("float16", (25088,), scope=tik.scope_ubuf, + name="input_1_2_fractal_L1_local_UB") + + with tik_instance.for_range(0, 2) as eeb0: + with tik_instance.for_range(0, 32) as i: + tik_instance.data_move(input_1_1_local_L1[i * 3136], input_x[i, block_index * 2 + eeb0, 0, 0, 0], 0, + 1, 196, 0, 0) + with tik_instance.for_range(0, 32) as i: + rep = 4 + fetch_filter_w = 0 + fetch_filter_h = 0 + left_top_w = 0 + left_top_h = 0 + c1_index = 0 + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB[1024 * i], input_1_1_local_L1[3136 * i], + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + rep) + with tik_instance.for_range(0, 32) as i: + tik_instance.data_move(input_1_2_fractal_L1_local_UB[i * 49 * 16], + input_1_1_fractal_L1_local_UB[i * 1024], 0, 1, 49, 0, 0) + + with tik_instance.for_range(0, 98) as i: + tik_instance.data_move(res[eeb0 + block_index * 2, i, 0, 0], input_1_2_fractal_L1_local_UB[256 * i], + 0, 1, 16, 0, 0) + + if input_shape == ((32, 32, 7, 7, 16), 'float16', (3, 3), (1, 1)): + pad = [1, 1, 1, 1] + l1_h = 7 + l1_w = 7 + jump_stride = 1 + repeat_mode = 1 + with tik_instance.for_range(0, 32, block_num=32) as block_index: + input_1_1_local_L1 = tik_instance.Tensor("float16", (25088,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (32768,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + input_1_2_fractal_L1_local_UB = tik_instance.Tensor("float16", (25088,), scope=tik.scope_ubuf, + name="input_1_2_fractal_L1_local_UB") + with tik_instance.for_range(0, 32) as i: + tik_instance.data_move(input_1_1_local_L1[i * 784], input_x[i, block_index, 0, 0, 0], 0, 1, 49, 0, 0) + + with tik_instance.for_range(0, 9) as eeb: + rep = 4 + fetch_filter_w = eeb % 3 + fetch_filter_h = eeb // 3 + left_top_w = -1 + left_top_h = -1 + c1_index = 0 + with tik_instance.for_range(0, 32) as i: + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB[1024 * i], input_1_1_local_L1[784 * i], + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + rep) + with tik_instance.for_range(0, 32) as i: + tik_instance.data_move(input_1_2_fractal_L1_local_UB[i * 49 * 16], + input_1_1_fractal_L1_local_UB[i * 1024], 0, 1, 49, 0, 0) + + with tik_instance.for_range(0, 98) as i: + tik_instance.data_move(res[eeb + block_index * 9, i, 0, 0], input_1_2_fractal_L1_local_UB[256 * i], + 0, 1, 16, 0, 0) + + if input_shape == ((32, 128, 7, 7, 16), 'float16', (1, 1), (1, 1)): + pad = [0, 0, 0, 0] + l1_h = 7 + l1_w = 7 + jump_stride = 1 + repeat_mode = 1 + with tik_instance.for_range(0, 32, block_num=32) as block_index: + input_1_1_local_L1 = tik_instance.Tensor("float16", (25088,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (32768,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + input_1_2_fractal_L1_local_UB = tik_instance.Tensor("float16", (25088,), scope=tik.scope_ubuf, + name="input_1_2_fractal_L1_local_UB") + with tik_instance.for_range(0, 4) as eeb0: + with tik_instance.for_range(0, 32) as i: + tik_instance.data_move(input_1_1_local_L1[i * 784], input_x[i, eeb0 + block_index * 4, 0, 0, 0], 0, + 1, 49, 0, 0) + with tik_instance.for_range(0, 32) as i: + rep = 4 + fetch_filter_w = 0 + fetch_filter_h = 0 + left_top_w = 0 + left_top_h = 0 + c1_index = 0 + with tik_instance.for_range(0, 32) as i: + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB[1024 * i], input_1_1_local_L1[784 * i], + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + rep) + + with tik_instance.for_range(0, 32) as i: + tik_instance.data_move(input_1_2_fractal_L1_local_UB[i * 49 * 16], + input_1_1_fractal_L1_local_UB[i * 1024], 0, 1, 49, 0, 0) + + with tik_instance.for_range(0, 98) as i: + tik_instance.data_move(res[eeb0 + block_index * 4, i, 0, 0], input_1_2_fractal_L1_local_UB[256 * i], + 0, 1, 16, 0, 0) + + if input_shape == ((32, 64, 14, 14, 16), 'float16', (1, 1), (1, 1)): + pad = [0, 0, 0, 0] + l1_h = 14 + l1_w = 14 + jump_stride = 1 + repeat_mode = 1 + with tik_instance.for_range(0, 32, block_num=32) as block_index: + input_1_1_local_L1 = tik_instance.Tensor("float16", (100352,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_2_local_L1 = tik_instance.Tensor("float16", (100352,), scope=tik.scope_cbuf, + name="input_1_2_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (53248,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + input_1_2_fractal_L1_local_UB = tik_instance.Tensor("float16", (50176,), scope=tik.scope_ubuf, + name="input_1_2_fractal_L1_local_UB") + with tik_instance.for_range(0, 32) as i: + tik_instance.data_move(input_1_1_local_L1[i * 3136], input_x[i, block_index * 2, 0, 0, 0], 0, 1, 196, 0, + 0) + tik_instance.data_move(input_1_2_local_L1[i * 3136], input_x[i, block_index * 2 + 1, 0, 0, 0], 0, 1, + 196, 0, 0) + with tik_instance.for_range(0, 2) as eeb1: + with tik_instance.for_range(eeb1 * 16, (eeb1 + 1) * 16) as i: + rep = 13 + fetch_filter_w = 0 + fetch_filter_h = 0 + left_top_w = 0 + left_top_h = 0 + c1_index = 0 + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB[3328 * (i - eeb1 * 16)], + input_1_1_local_L1[3136 * i], + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + rep) + with tik_instance.for_range(0, 16) as i: + tik_instance.data_move(input_1_2_fractal_L1_local_UB[i * 196 * 16], + input_1_1_fractal_L1_local_UB[i * 3328], 0, 1, 196, 0, 0) + with tik_instance.for_range(eeb1 * 196, (eeb1 + 1) * 196) as i: + tik_instance.data_move(res[block_index * 2, i, 0, 0], + input_1_2_fractal_L1_local_UB[256 * (i - eeb1 * 196)], 0, 1, 16, 0, 0) + + with tik_instance.for_range(0, 2) as eeb1: + with tik_instance.for_range(eeb1 * 16, (eeb1 + 1) * 16) as i: + rep = 13 + fetch_filter_w = 0 + fetch_filter_h = 0 + left_top_w = 0 + left_top_h = 0 + c1_index = 0 + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB[3328 * (i - eeb1 * 16)], + input_1_2_local_L1[3136 * i], + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + rep) + with tik_instance.for_range(0, 16) as i: + tik_instance.data_move(input_1_2_fractal_L1_local_UB[i * 196 * 16], + input_1_1_fractal_L1_local_UB[i * 3328], 0, 1, 196, 0, 0) + with tik_instance.for_range(eeb1 * 196, (eeb1 + 1) * 196) as i: + tik_instance.data_move(res[block_index * 2 + 1, i, 0, 0], + input_1_2_fractal_L1_local_UB[256 * (i - eeb1 * 196)], 0, 1, 16, 0, 0) + + if input_shape == ((32, 32, 28, 28, 16), 'float16', (1, 1), (2, 2)): + pad = [0, 0, 0, 0] + l1_h = 28 + l1_w = 28 + jump_stride = 1 + repeat_mode = 1 + with tik_instance.for_range(0, 32, block_num=32) as block_index: + input_1_1_local_L1 = tik_instance.Tensor("float16", (401408,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (53248,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + input_1_2_fractal_L1_local_UB = tik_instance.Tensor("float16", (50176,), scope=tik.scope_ubuf, + name="input_1_2_fractal_L1_local_UB") + with tik_instance.for_range(0, 32) as i: + tik_instance.data_move(input_1_1_local_L1[i * 12544], input_x[i, block_index, 0, 0, 0], 0, 1, 784, 0, 0) + with tik_instance.for_range(0, 16) as i: + rep = 13 + fetch_filter_w = 0 + fetch_filter_h = 0 + left_top_w = 0 + left_top_h = 0 + c1_index = 0 + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB[3328 * i], input_1_1_local_L1[12544 * i], + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + rep) + with tik_instance.for_range(0, 16) as i: + tik_instance.data_move(input_1_2_fractal_L1_local_UB[i * 196 * 16], + input_1_1_fractal_L1_local_UB[i * 3328], 0, 1, 196, 0, 0) + with tik_instance.for_range(0, 196) as i: + tik_instance.data_move(res[block_index, i, 0, 0], input_1_2_fractal_L1_local_UB[256 * i], 0, 1, 16, 0, + 0) + + with tik_instance.for_range(16, 32) as i: + rep = 13 + fetch_filter_w = 0 + fetch_filter_h = 0 + left_top_w = 0 + left_top_h = 0 + c1_index = 0 + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB[3328 * (i - 16)], input_1_1_local_L1[12544 * i], + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + rep) + with tik_instance.for_range(0, 16) as i: + tik_instance.data_move(input_1_2_fractal_L1_local_UB[i * 196 * 16], + input_1_1_fractal_L1_local_UB[i * 3328], 0, 1, 196, 0, 0) + with tik_instance.for_range(196, 392) as i: + tik_instance.data_move(res[block_index, i, 0, 0], input_1_2_fractal_L1_local_UB[256 * (i - 196)], 0, 1, + 16, 0, 0) + + if input_shape == ((32, 32, 7, 7, 16), 'float16', (1, 1), (1, 1)): + if padding == 'SAME': + padding_left = 0 + padding_right = 0 + padding_top = 0 + padding_bottom = 0 + pad = [padding_left, padding_right, padding_top, padding_bottom] + l1_h = 7 + l1_w = 7 + c1_index = 0 + jump_stride = 1 + repeat_mode = 1 + with tik_instance.for_range(0, 32, block_num=32) as block_index: + input_1_1_local_L1 = tik_instance.Tensor("float16", (25088,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (32768,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + input_1_2_fractal_L1_local_UB = tik_instance.Tensor("float16", (25088,), scope=tik.scope_ubuf, + name="input_1_2_fractal_L1_local_UB") + + with tik_instance.for_range(0, 32) as i: + tik_instance.data_move(input_1_1_local_L1[i * 784], input_x[i, block_index, 0, 0, 0], 0, 1, 49, 0, 0) + + with tik_instance.for_range(0, 32) as i: + fetch_filter_w = 0 + fetch_filter_h = 0 + left_top_h = 0 + left_top_w = 0 + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB[1024 * i], input_1_1_local_L1[784 * i], + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + 4) + + with tik_instance.for_range(0, 32) as i: + tik_instance.data_move(input_1_2_fractal_L1_local_UB[i * 49 * 16], + input_1_1_fractal_L1_local_UB[i * 1024], 0, 1, 49, 0, 0) + with tik_instance.for_range(0, 98) as i: + tik_instance.data_move(res[block_index, i, 0, 0], input_1_2_fractal_L1_local_UB[i * 256], 0, 1, 16, 0, + 0) + + if input_shape == ((32, 4, 56, 56, 16), 'float16', (1, 1), (1, 1)): + if padding == 'SAME': + padding_left = 0 + padding_right = 0 + padding_top = 0 + padding_bottom = 0 + pad = [padding_left, padding_right, padding_top, padding_bottom] + l1_h = 56 + l1_w = 56 + c1_index = 0 + jump_stride = 1 + repeat_mode = 1 + with tik_instance.for_range(0, 32, block_num=32) as block_index: + input_1_1_local_L1 = tik_instance.Tensor("float16", (12544 * 32 // 2,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (100352 // 2,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + tik_instance.data_move(input_1_1_local_L1, input_x[block_index, 0, 0, 0, 0], 0, 1, 12544, 0, 0) + with tik_instance.for_range(0, 4) as eeb: + fetch_filter_w = 0 + fetch_filter_h = 0 + left_top_h = 0 + left_top_w = 0 + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB, input_1_1_local_L1[eeb * 56 * 56 * 16], + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + 196) + with tik_instance.for_range(0, 196) as rep: + tik_instance.data_move(res[eeb, rep + block_index * 196, 0, 0], + input_1_1_fractal_L1_local_UB[rep * 256], 0, 1, 16, 0, 0) + + if input_shape == ((32, 8, 28, 28, 16), 'float16', (1, 1), (1, 1)): + if padding == 'SAME': + padding_left = 0 + padding_right = 0 + padding_top = 0 + padding_bottom = 0 + pad = [padding_left, padding_right, padding_top, padding_bottom] + l1_h = 28 + l1_w = 28 + c1_index = 0 + jump_stride = 1 + repeat_mode = 1 + with tik_instance.for_range(0, 32, block_num=32) as block_index: + input_1_1_local_L1 = tik_instance.Tensor("float16", (6272 * 32 // 2,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (49 * 256 * 8,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + tik_instance.data_move(input_1_1_local_L1, input_x[block_index, 0, 0, 0, 0], 0, 1, 6272, 0, 0) + with tik_instance.for_range(0, 1) as eeb0: + with tik_instance.for_range(0, 8) as eeb1: + fetch_filter_w = 0 + fetch_filter_h = 0 + left_top_h = 0 + left_top_w = 0 + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB[eeb1 * 49 * 256], + input_1_1_local_L1[(eeb1 + eeb0 * 8) * 28 * 28 * 16], + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + 49) + with tik_instance.for_range(0, 8) as eeb1: + with tik_instance.for_range(0, 49) as i: + tik_instance.data_move(res[eeb0 * 8 + eeb1, i + block_index * 49, 0, 0], + input_1_1_fractal_L1_local_UB[i * 256 + eeb1 * 49 * 256], 0, 1, 16, 0, 0) + + if input_shape == ((32, 32, 28, 28, 16), 'float16', (1, 1), (1, 1)): + if padding == 'SAME': + padding_left = 0 + padding_right = 0 + padding_top = 0 + padding_bottom = 0 + pad = [padding_left, padding_right, padding_top, padding_bottom] + l1_h = 28 + l1_w = 28 + c1_index = 0 + jump_stride = 1 + repeat_mode = 1 + with tik_instance.for_range(0, 32, block_num=32) as block_index: + input_1_1_local_L1 = tik_instance.Tensor("float16", (25088 * 32 // 2,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (49 * 256 * 8,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + tik_instance.data_move(input_1_1_local_L1, input_x[block_index, 0, 0, 0, 0], 0, 1, 25088, 0, 0) + with tik_instance.for_range(0, 4) as eeb0: + with tik_instance.for_range(0, 8) as eeb1: + fetch_filter_w = 0 + fetch_filter_h = 0 + left_top_h = 0 + left_top_w = 0 + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB[eeb1 * 49 * 256], + input_1_1_local_L1[(eeb1 + eeb0 * 8) * 28 * 28 * 16], + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + 49) + with tik_instance.for_range(0, 8) as eeb1: + with tik_instance.for_range(0, 49) as i: + tik_instance.data_move(res[eeb0 * 8 + eeb1, i + block_index * 49, 0, 0], + input_1_1_fractal_L1_local_UB[i * 256 + eeb1 * 49 * 256], 0, 1, 16, 0, 0) + + if input_shape == ((32, 16, 14, 14, 16), 'float16', (1, 1), (1, 1)): + if padding == 'SAME': + padding_left = 0 + padding_right = 0 + padding_top = 0 + padding_bottom = 0 + pad = [padding_left, padding_right, padding_top, padding_bottom] + l1_h = 14 + l1_w = 14 + c1_index = 0 + jump_stride = 1 + repeat_mode = 1 + with tik_instance.for_range(0, 32, block_num=32) as block_index: + eeb0 = block_index % 2 + eeb1 = block_index // 2 + input_1_1_local_L1 = tik_instance.Tensor("float16", (196 * 32 * 16,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (106496 // 2,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + input_1_2_fractal_L1_local_UB = tik_instance.Tensor("float16", (196 * 16 * 16,), scope=tik.scope_ubuf, + name="input_1_2_fractal_L1_local_UB") + with tik_instance.for_range(0, 32) as i: + tik_instance.data_move(input_1_1_local_L1[i * 3136], input_x[i, eeb1, 0, 0, 0], 0, 1, 196, 0, 0) + with tik_instance.for_range(0, 16) as i: + fetch_filter_w = 0 + fetch_filter_h = 0 + left_top_h = 0 + left_top_w = 0 + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB[i * 3328], + input_1_1_local_L1[i * 3136 + eeb0 * 16 * 3136], + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + 13) + with tik_instance.for_range(0, 16) as i: + tik_instance.data_move(input_1_2_fractal_L1_local_UB[i * 196 * 16], + input_1_1_fractal_L1_local_UB[i * 3328], 0, 1, 196, 0, 0) + with tik_instance.for_range(0, 196) as i: + tik_instance.data_move(res[eeb1, i + 196 * eeb0, 0, 0], input_1_2_fractal_L1_local_UB[256 * i], 0, 1, + 16, 0, 0) + + if input_shape == ((32, 16, 56, 56, 16), 'float16', (1, 1), (1, 1)): + if padding == 'SAME': + padding_left = 0 + padding_right = 0 + padding_top = 0 + padding_bottom = 0 + pad = [padding_left, padding_right, padding_top, padding_bottom] + l1_h = 56 + l1_w = 56 + c1_index = 0 + jump_stride = 1 + repeat_mode = 1 + with tik_instance.for_range(0, 32, block_num=32) as block_index: + input_1_1_local_L1 = tik_instance.Tensor("float16", (25088 * 32 // 2,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (196 * 256 * 2,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + with tik_instance.for_range(0, 2) as eeb0: + tik_instance.data_move(input_1_1_local_L1, input_x[block_index, eeb0 * 8, 0, 0, 0], 0, 1, 25088, 0, 0) + with tik_instance.for_range(0, 4) as eeb1: + with tik_instance.for_range(0, 2) as eeb2: + fetch_filter_w = 0 + fetch_filter_h = 0 + left_top_h = 0 + left_top_w = 0 + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB[eeb2 * 196 * 256], + input_1_1_local_L1[(eeb2 + eeb1 * 2) * 56 * 56 * 16], + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + 196) + with tik_instance.for_range(0, 2) as eeb2: + with tik_instance.for_range(0, 196) as i: + tik_instance.data_move(res[eeb0 * 8 + eeb1 * 2 + eeb2, i + block_index * 196, 0, 0], + input_1_1_fractal_L1_local_UB[256 * i + eeb2 * 196 * 256], 0, 1, 16, + 0, 0) + + if input_shape == ((32, 16, 56, 56, 16), 'float16', (1, 1), (2, 2)): + if padding == 'SAME': + padding_left = 0 + padding_right = 0 + padding_top = 0 + padding_bottom = 0 + pad = [padding_left, padding_right, padding_top, padding_bottom] + l1_h = 56 + l1_w = 56 + c1_index = 0 + jump_stride = 1 + repeat_mode = 1 + with tik_instance.for_range(0, 32, block_num=32) as block_index: + input_1_1_local_L1 = tik_instance.Tensor("float16", (25088 * 32 // 2,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (49 * 256 * 8,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + with tik_instance.for_range(0, 2) as eeb0: + tik_instance.data_move(input_1_1_local_L1, input_x[block_index, eeb0 * 8, 0, 0, 0], 0, 1, 25088, 0, 0) + with tik_instance.for_range(0, 8) as eeb1: + fetch_filter_w = 0 + fetch_filter_h = 0 + left_top_h = 0 + left_top_w = 0 + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB[eeb1 * 49 * 256], + input_1_1_local_L1[eeb1 * 56 * 56 * 16], + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + 49) + with tik_instance.for_range(0, 8) as eeb1: + with tik_instance.for_range(0, 49) as i: + tik_instance.data_move(res[eeb0 * 8 + eeb1, i + block_index * 49, 0, 0], + input_1_1_fractal_L1_local_UB[256 * i + eeb1 * 49 * 256], 0, 1, 16, 0, 0) + tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x], outputs=[res]) + return tik_instance diff --git a/mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py b/mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py new file mode 100644 index 00000000000..e5c380369d0 --- /dev/null +++ b/mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py @@ -0,0 +1,468 @@ +# -*- coding:utf-8 -*- +""" +copyright 2020 Huawei Technologies Co., Ltd + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License == distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +matmul +""" +from __future__ import absolute_import +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType +import te.lang.cce +import te.platform.cce_params as cce +from te import tik +from te import tvm +from topi import generic +from topi.cce import util + +# General limitation of the size for input shape: 2**31 +SHAPE_SIZE_LIMIT = 2147483648 +NoneType = type(None) + +matmul_cube_dense_left_op_info = TBERegOp("CusMatMulCubeDenseLeft") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("matmulcubedenseleft.so") \ + .compute_cost(10) \ + .kernel_name("CusMatMulCubeDenseLeft") \ + .partial_flag(True) \ + .input(0, "x1", False, "required", "all") \ + .input(1, "x2", False, "required", "all") \ + .input(2, "x3", False, "optional", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.F16_FracNZ, DataType.F16_Default, DataType.F16_FracNZ) \ + .get_op_info() + + +# pylint: disable=locally-disabled,too-many-arguments,too-many-branches, too-many-statements, too-many-locals, +def _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b): + """ + Check the given input if legal + + Parameters: + shape_a: list or tuple + Shape of the first tensor a with rank > 1 + shape_b: list or tuple + Shape of the second tensor b with the same type with a, + and shape_a, shape_b must be 2 dims + shape_bias: list or tuple + Shape of bias, only support the input data format with ND + src_dtype: str + The data type of input, support "float32", "float16" + trans_a: bool + If True, shape_a == transposed before multiplication + trans_b: bool + If True, shape_b == transposed before multiplication + + Returns None + """ + shape_len = len(shape_a) + src_dtype = src_dtype.lower() + k_block_size = cce.BLOCK_REDUCE + + check_list = ("float16") + + if src_dtype not in check_list: + raise RuntimeError("matmul_cce only support %s while src_dtype == %s" + % (",".join(check_list), src_dtype)) + if shape_len != len(shape_b): + raise RuntimeError("length of a and b are not equal") + + if shape_len != 2: + raise RuntimeError( + "length of shape must be 2, more than 2 dimensions should use batch_matmul now!") + + is_gevm = True if shape_a[-2] == 1 or shape_a[-1] == 1 else False + is_gemv = True if shape_b[-2] == 1 or shape_b[-1] == 1 else False + + if trans_a: + m_shape = shape_a[shape_len - 1] + km_shape = shape_a[shape_len - 2] + else: + m_shape = shape_a[shape_len - 2] + km_shape = shape_a[shape_len - 1] + + if trans_b: + kn_shape = shape_b[shape_len - 1] + n_shape = shape_b[shape_len - 2] + else: + kn_shape = shape_b[shape_len - 2] + n_shape = shape_b[shape_len - 1] + + if m_shape == 1: + if n_shape == 1: + raise RuntimeError("input shape M and N can't both be 1") + + if km_shape != kn_shape: + print(km_shape, kn_shape) + raise RuntimeError("reduce axis not same") + + if m_shape % cce.BLOCK_IN != 0 and m_shape != 1: + raise RuntimeError( + "input shape M should be 1 or multiple of %d" % cce.BLOCK_IN) + + if m_shape != 1: + if n_shape == 1: + if km_shape % (cce.BLOCK_IN * cce.BLOCK_IN) != 0: + raise RuntimeError("input shape K1 should be multiple of %d" + % (cce.BLOCK_IN * cce.BLOCK_IN)) + elif km_shape % k_block_size != 0: + raise RuntimeError( + "input shape K1 should be multiple of %d" % cce.BLOCK_IN) + else: + if km_shape % (cce.BLOCK_IN * cce.BLOCK_IN) != 0: + raise RuntimeError("input shape K1 should be multiple of %d" + % (cce.BLOCK_IN * cce.BLOCK_IN)) + + if n_shape % cce.BLOCK_IN != 0 and n_shape != 1: + raise RuntimeError("input shape N should be 1 or multiple of %d" % cce.BLOCK_IN) + + if len(shape_bias) != 0: + if len(shape_bias) == 1: + if is_gevm or is_gemv: + if shape_bias[0] != m_shape * n_shape: + raise RuntimeError("broadcast case shape bias for gemv must be equal m*n") + else: + if shape_bias[0] != n_shape: + raise RuntimeError("broadcast bias shape must be equal to shape n") + elif len(shape_bias) == shape_len: + if [i for i in shape_bias[-2:]] != [m_shape, n_shape]: + raise RuntimeError("non broadcast bias shape must be same as output shape") + else: + raise RuntimeError("unsupport input shape now for batch bias case") + + +def _get_bias(shape_bias): + """_get_bias""" + bias_length = shape_bias[0] + shb = [] + if bias_length % 16 == 0: + shb = shape_bias + else: + bias_length = (bias_length // 16) * 16 + 16 + shape_bias = [] + shape_bias.append(bias_length) + shb = shape_bias + return shb + + +def _get_input_shape(shape_x): + """_get_input_shape""" + dim_a = shape_x[0] + dim_b = shape_x[1] + res = [] + if dim_a % 16 != 0: + dim_a = (dim_a // 16) * 16 + 16 + res.append(dim_a) + else: + res.append(dim_a) + + if dim_b % 16 != 0: + dim_b = (dim_b // 16) * 16 + 16 + res.append(dim_b) + else: + res.append(dim_b) + return res + + +def check_supported(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, kernel_name="matmulcube"): + """check_supported""" + shape_a = input_x1.get("shape") + shape_b = input_x2.get("shape") + print("shape_a: ", shape_a) + print("shape_b: ", shape_b) + src_dtype = input_x1.get("dtype") + util.check_kernel_name(kernel_name) + util.check_shape_rule(shape_a) + util.check_shape_rule(shape_b) + util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT) + util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT) + try: + trans_a_f = bool(1 - trans_a) + if src_dtype == "float32" or src_dtype == "int32": + if len(shape_a) != 2 and len(shape_b) != 2: + return False + if trans_b: + if shape_b[0] == 1: + return False + else: + if shape_b[1] == 1: + return False + if trans_a: + if trans_b: + if shape_a[0] != shape_b[1]: + return False + elif shape_a[0] != shape_b[0]: + return False + elif trans_b: + if shape_a[1] != shape_b[1]: + return False + elif shape_a[1] != shape_b[0]: + return False + + if trans_a_f and trans_b and shape_b[1] == 1: + return False + + if src_dtype == "float16": + if len(shape_a) != 2 and len(shape_b) != 2: + return False + + if trans_a: + m_shape = shape_a[1] + k_shape = shape_a[0] + else: + m_shape = shape_a[0] + k_shape = shape_a[1] + + if trans_b: + n_shape = shape_b[0] + k_b_shape = shape_b[1] + else: + n_shape = shape_b[1] + k_b_shape = shape_b[0] + + if k_shape != k_b_shape: + return False + + if m_shape == 1 or n_shape == 1: + if k_shape % 256 != 0: + return False + + except RuntimeError as e: + return False + + return True + + +# pylint: disable=locally-disabled,too-many-arguments, too-many-locals, too-many-statements +# @util.check_input_type(dict, dict, (dict, NoneType), dict, bool, bool, str) +@op_info_register(matmul_cube_dense_left_op_info) +def CusMatMulCubeDenseLeft(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, + kernel_name="matmulcube"): + """ + calculating matrix multiplication with bias, C = A*B + bias, support input + data with fractal format. + + Parameters: + shape_a: list or tuple + Shape of the first tensor a with rank > 1 + shape_b: list or tuple + Shape of the second tensor b with the same type with a, + and shape_a, shape_b must be 2 dims + src_dtype: str + The data type of input, support "float32", "float16" + dst_dtype: str + The data type of output, support "float32", "float16" + trans_a: bool + If True, shape_a == transposed before multiplication + trans_b: bool + If True, shape_b == transposed before multiplication + is_fractal: bool + If True, the input data format of a and b must be fractal format + shape_bias: list or tuple + Shape of bias, only support the input data format with ND + + Returns + ------- + None + """ + print("!!!!come into zzt~~~~~~~!!!!") + shape_a = input_x1.get("ori_shape") + shape_b = input_x2.get("ori_shape") + shape_output = output_y.get("ori_shape") + print("============") + print(input_x1.get("format"), input_x2.get("format")) + print(shape_a, shape_b) + print("============") + if input_x2.get("format") == "FRACTAL_Z": + n, c, h, w = shape_b + c0 = 16 + c1 = c // c0 + if c1 == 0: + c1 = 1 + shape_b = [n, c1 * h * w * c0] + shape_a = [n, n] + + if input_x1.get("format") == "FRACTAL_Z": + n, c, h, w = shape_a + c0 = 16 + c1 = c // c0 + if c1 == 0: + c1 = 1 + shape_a = [n, c1 * h * w * c0] + shape_b = [c1 * h * w * c0, c1 * h * w * c0] + + if input_x2.get("format") == "FRACTAL_NZ": + shape_a = [shape_b[0], shape_b[0]] + shape_b = shape_b + + if input_x1.get("format") == "FRACTAL_NZ": + shape_a = shape_a + shape_b = [shape_a[1], shape_a[1]] + + shape_a = list(shape_a) + shape_b = list(shape_b) + + shape_a = _get_input_shape(shape_a) + shape_b = _get_input_shape(shape_b) + + util.check_kernel_name(kernel_name) + util.check_shape_rule(shape_a) + util.check_shape_rule(shape_b) + util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT) + util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT) + + shape_a = [shape_a[1], shape_a[0]] + trans_a = bool(1 - trans_a) + + shape_b = [shape_b[1], shape_b[0]] + trans_b = bool(1 - trans_b) + + shape_bias = () + if bias is not None and bool(bias): + shape_bias = bias.get("shape") + shape_bias = list(shape_bias) + shape_bias = _get_bias(shape_bias) + + src_dtype = input_x1.get("dtype").lower() + dst_dtype = output_y.get("dtype").lower() + _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b) + + m_shape = shape_a[len(shape_a) - 2] + km_shape = shape_a[len(shape_a) - 1] + kn_shape = shape_b[len(shape_a) - 2] + n_shape = shape_b[len(shape_a) - 1] + + if src_dtype == "float16": + block_reduce = cce.BLOCK_REDUCE + + block_in = cce.BLOCK_IN + block_out = cce.BLOCK_OUT + + if trans_a and km_shape == 1: + block_in = cce.BLOCK_VECTOR + + if not trans_a and m_shape == 1: + block_in = cce.BLOCK_VECTOR + + if trans_b and kn_shape == 1: + block_out = cce.BLOCK_VECTOR + + if not trans_b and n_shape == 1: + block_out = cce.BLOCK_VECTOR + + if trans_a: + shape_a_temp = (m_shape // block_reduce, km_shape // block_in, block_reduce, block_in) + else: + shape_a_temp = (m_shape // block_in, km_shape // block_reduce, block_in, block_reduce) + + if trans_b: + shape_b_temp = (kn_shape // block_out, n_shape // block_reduce, block_reduce, block_out) + else: + shape_b_temp = (kn_shape // block_reduce, n_shape // block_out, block_out, block_reduce) + shape_a_temp = (shape_a_temp[0], shape_a_temp[1], shape_a_temp[2], shape_a_temp[3]) + format_a = "FRACTAL_NZ" + shape_b_temp = (shape_b_temp[0], shape_b_temp[1], shape_b_temp[2], shape_b_temp[3]) + format_b = "FRACTAL_NZ" + + print("=======================================") + print(shape_a_temp, shape_b_temp) + print(format_a, format_b) + print("=======================================") + tensor_bias = None + tensor_a = tvm.placeholder(shape_a_temp, name='tensor_a', + dtype=src_dtype) + tensor_b = tvm.placeholder(shape_b_temp, name='tensor_b', + dtype=src_dtype) + + if len(shape_bias) > 0: + tensor_bias = tvm.placeholder(shape_bias, name='tensor_bias', + dtype=dst_dtype) + + if shape_a_temp[0] == 63 and shape_a_temp[1] == 63 and shape_b_temp[0] == 128 and shape_b_temp[1] == 63: + if util.get_product_version() == util.VERSION_MINI: + tik_instance = tik.Tik(tik.Dprofile("v100", "mini")) + else: + tik_instance = tik.Tik(tik.Dprofile("v100", "cloud")) + + input_x1 = tik_instance.Tensor("float16", shape_a_temp, name="left_matrix", scope=tik.scope_gm) + input_x2 = tik_instance.Tensor("float16", shape_b_temp, name="right_matrix", scope=tik.scope_gm) + resMatmul = tik_instance.Tensor("float16", shape_output, name="output", scope=tik.scope_gm) + with tik_instance.for_range(0, 32, block_num=32) as block_index: + resMatmul_local_UB = tik_instance.Tensor("float16", (128 * 256,), scope=tik.scope_ubuf, + name="resMatmul_local_UB") + resMatmul_local_UB_local_L0C = tik_instance.Tensor("float32", (128 * 256,), scope=tik.scope_cc, + name="resMatmul_local_UB") + input_1_local_L1_local_L0A = tik_instance.Tensor("float16", (128 * 128,), scope=tik.scope_ca, + name="input_1_local_L1_local_L0A") + input_2_local_L1 = tik_instance.Tensor("float16", (128 * 256,), scope=tik.scope_cbuf, + name="input_2_local_L1") + input_1_local_L1 = tik_instance.Tensor("float16", (128 * 128,), scope=tik.scope_cbuf, + name="input_1_local_L1") + input_2_local_L1_local_L0B = tik_instance.Tensor("float16", (128 * 256,), scope=tik.scope_cb, + name="input_2_local_L1_local_L0B") + core_m_idx = block_index % 8 + core_n_idx = block_index // 8 + with tik_instance.if_scope(core_m_idx != 7): + tik_instance.data_move(input_1_local_L1, input_x1[core_m_idx * (8 * 256 + 128 * 1008)], 0, 8, 128, + 55 * 16, 0) + tik_instance.data_move(input_2_local_L1, input_x2[core_m_idx * 8 * 256 + core_n_idx * 512 * 1008], 0, + 32, 128, 55 * 16, 0) + with tik_instance.for_range(0, 8) as cc12: + tik_instance.load2dv1(input_1_local_L1_local_L0A[cc12 * 2048], input_1_local_L1[cc12 * 256], 0, 8, + 8, 0, False) + with tik_instance.for_range(0, 2) as cc6: + with tik_instance.for_range(0, 8) as cc121: + tik_instance.load2dv1(input_2_local_L1_local_L0B[cc121 * 4096], + input_2_local_L1[cc6 * 32768 + cc121 * 256], 0, 16, 8, 0, True) + tik_instance.mmad(resMatmul_local_UB_local_L0C, input_1_local_L1_local_L0A, + input_2_local_L1_local_L0B, 128, 128, 256, 0) + tik_instance.data_move(resMatmul_local_UB, resMatmul_local_UB_local_L0C, 0, 1, 128, 0, 0, 1) + tik_instance.data_move(resMatmul[cc6 * 256 * 1008 + core_m_idx * 8 * 256 + core_n_idx * 512 * 1008], + resMatmul_local_UB, 0, 16, 256 // 2, 0, 55 * 16 * 2 // 2) + with tik_instance.else_scope(): + tik_instance.data_move(input_1_local_L1, input_x1[core_m_idx * (8 * 256 + 128 * 1008)], 0, 7, 112, + 56 * 16, 0) + tik_instance.data_move(input_2_local_L1, input_x2[core_m_idx * 8 * 256 + core_n_idx * 512 * 1008], 0, + 32, 112, 56 * 16, 0) + with tik_instance.for_range(0, 7) as cc10: + tik_instance.load2dv1(input_1_local_L1_local_L0A[cc10 * 1792], input_1_local_L1[cc10 * 256], 0, 7, + 7, 0, False) + with tik_instance.for_range(0, 2) as cc5: + with tik_instance.for_range(0, 7) as cc101: + tik_instance.load2dv1(input_2_local_L1_local_L0B[cc101 * 4096], + input_2_local_L1[cc5 * 28672 + cc101 * 256], 0, 16, 7, 0, True) + tik_instance.mmad(resMatmul_local_UB_local_L0C, input_1_local_L1_local_L0A, + input_2_local_L1_local_L0B, 112, 112, 256, 0) + tik_instance.data_move(resMatmul_local_UB, resMatmul_local_UB_local_L0C, 0, 1, 112, 0, 0, 1) + tik_instance.data_move(resMatmul[cc5 * 256 * 1008 + core_m_idx * 8 * 256 + core_n_idx * 512 * 1008], + resMatmul_local_UB, 0, 16, 224 // 2, 0, 56 * 16 * 2 // 2) + tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x1, input_x2], outputs=[resMatmul]) + return tik_instance + else: + print("come into tbe, shape is error!") + result = te.lang.cce.matmul(tensor_a, tensor_b, trans_a, trans_b, format_a=format_a, + format_b=format_b, dst_dtype=dst_dtype, tensor_bias=tensor_bias) + + with tvm.target.cce(): + schedule = generic.auto_schedule(result) + + tensor_list = [tensor_a, tensor_b, result] + if len(shape_bias) > 0: + tensor_list = [tensor_a, tensor_b, tensor_bias, result] + + config = {"print_ir": False, + "name": kernel_name, + "tensor_list": tensor_list} + + te.lang.cce.cce_build_code(schedule, config) diff --git a/mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py b/mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py new file mode 100644 index 00000000000..4a1982738d6 --- /dev/null +++ b/mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +""" +copyright 2020 Huawei Technologies Co., Ltd + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License == distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +matmul +""" +from __future__ import absolute_import + +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType +from te import tik +from topi.cce import util + +matmul_cube_dense_right_op_info = TBERegOp("CusMatMulCubeDenseRight") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("matmulcubedenseright.so") \ + .compute_cost(10) \ + .kernel_name("CusMatMulCubeDenseRight") \ + .partial_flag(True) \ + .input(0, "x1", False, "required", "all") \ + .input(1, "x2", False, "required", "all") \ + .input(2, "x3", False, "required", "all") \ + .input(3, "x4", False, "optional", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_FracNZ, DataType.F16_Default, DataType.F32_Default, DataType.F16_Default, + DataType.F32_FracNZ) \ + .get_op_info() + + +@op_info_register(matmul_cube_dense_right_op_info) +def CusMatMulCubeDenseRight(input_x1, input_x2, input_x3, bias=None, output_y={}, trans_a=False, trans_b=False, + kernel_name="matmulcube"): + """CusMatMulCubeDenseRight""" + shape_a_temp = (128, 63, 16, 16) + shape_b_temp = (128, 128, 16, 16) + shape_output = output_y.get("shape") + matrix_max_shape = (1,) + support_shape = [(shape_a_temp, shape_b_temp, matrix_max_shape),] + shape_a_input = input_x1.get("shape") + shape_b_input = input_x2.get("shape") + matrix_max_input = input_x3.get("shape") + input_shape = (tuple(shape_a_input), tuple(shape_b_input), tuple(matrix_max_input)) + if input_shape not in support_shape: + raise RuntimeError("input_shape %s is not supported" % str(input_shape)) + + if shape_a_temp[0] == 128 and shape_a_temp[1] == 63 and shape_b_temp[0] == 128 and shape_b_temp[1] == 128: + if util.get_product_version() == util.VERSION_MINI: + tik_instance = tik.Tik(tik.Dprofile("v100", "mini")) + else: + tik_instance = tik.Tik(tik.Dprofile("v100", "cloud")) + input_x1 = tik_instance.Tensor("float16", shape_a_temp, name="left_matrix", scope=tik.scope_gm) + input_x2 = tik_instance.Tensor("float16", shape_b_temp, name="right_matrix", scope=tik.scope_gm) + input_x3 = tik_instance.Tensor("float32", [1,], name="matrix_max", scope=tik.scope_gm) + resMatmul = tik_instance.Tensor("float32", shape_output, name="output", scope=tik.scope_gm) + with tik_instance.for_range(0, 32, block_num=32) as block_index: + core_m_idx = block_index // 16 + core_n_idx = block_index % 16 + matrix_max_scalar = tik_instance.Scalar("float32") + matrix_max_local_UB = tik_instance.Tensor("float32", (8,), scope=tik.scope_ubuf, name="matrix_max_local_UB") + tik_instance.data_move(matrix_max_local_UB, input_x3, 0, 1, 1, 0, 0) + matrix_max_scalar.set_as(matrix_max_local_UB[0]) + + resMatmul_local_UB = tik_instance.Tensor("float32", (256 * 128,), scope=tik.scope_ubuf, + name="resMatmul_local_UB") + resMatmul_local_UB1 = tik_instance.Tensor("float32", (240 * 128,), scope=tik.scope_ubuf, + name="resMatmul_local_UB1") + + resMatmul_local_UB_local_L0C = tik_instance.Tensor("float32", (256 * 128,), scope=tik.scope_cc, + name="resMatmul_local_UB_local_L0C") + resMatmul_local_UB_local_L0C1 = tik_instance.Tensor("float32", (240 * 128,), scope=tik.scope_cc, + name="resMatmul_local_UB_local_L0C1") + + input_1_local_L1_local_L0A = tik_instance.Tensor("float16", (256 * 128,), scope=tik.scope_ca, + name="input_1_local_L1_local_L0A") + input_2_local_L1 = tik_instance.Tensor("float16", (8 * 128 * 16,), scope=tik.scope_cbuf, + name="input_2_local_L1") + input_2_local_L11 = tik_instance.Tensor("float16", (8 * 128 * 16,), scope=tik.scope_cbuf, + name="input_2_local_L11") + + input_1_local_L1 = tik_instance.Tensor("float16", (8 * 256 * 16,), scope=tik.scope_cbuf, + name="input_1_local_L1") + input_1_local_L11 = tik_instance.Tensor("float16", (8 * 240 * 16,), scope=tik.scope_cbuf, + name="input_1_local_L11") + + input_2_local_L1_local_L0B = tik_instance.Tensor("float16", (128 * 128,), scope=tik.scope_cb, + name="input_2_local_L1_local_L0B") + input_2_local_L1_local_L0B1 = tik_instance.Tensor("float16", (128 * 128,), scope=tik.scope_cb, + name="input_2_local_L1_local_L0B1") + + with tik_instance.if_scope(core_m_idx == 0): + with tik_instance.for_range(0, 2) as cc1: + tik_instance.data_move(input_2_local_L1, input_x2[core_n_idx * 262144 + core_n_idx * 2048], 0, 8, + 128, 1920, 0) + tik_instance.data_move(input_1_local_L1, input_x1[core_n_idx * 129024 + cc1 * 4096], 0, 8, 256, 752, + 0) + with tik_instance.for_range(0, 8) as cc10: + tik_instance.load2dv1(input_2_local_L1_local_L0B[cc10 * 2048], input_2_local_L1[cc10 * 256], 0, + 8, 8, 0, True) + with tik_instance.for_range(0, 16) as cc101: + tik_instance.load2dv1(input_1_local_L1_local_L0A[cc101 * 2048], input_1_local_L1[cc101 * 256], + 0, 8, 16, 0, False) + + tik_instance.mmad(resMatmul_local_UB_local_L0C, input_1_local_L1_local_L0A, + input_2_local_L1_local_L0B, 256, 128, 128, 0) + tik_instance.data_move(resMatmul_local_UB, resMatmul_local_UB_local_L0C, 0, 1, 128, 0, 0) + tik_instance.vmuls(64, resMatmul_local_UB, resMatmul_local_UB, matrix_max_scalar, 255, 1, 1, 8, 8) + tik_instance.vmuls(64, resMatmul_local_UB[255 * 64], resMatmul_local_UB[255 * 64], + matrix_max_scalar, 255, 1, 1, 8, 8) + tik_instance.vmuls(64, resMatmul_local_UB[510 * 64], resMatmul_local_UB[510 * 64], + matrix_max_scalar, 2, 1, 1, 8, 8) + + tik_instance.data_move(resMatmul[core_n_idx * 129024 + cc1 * 4096], resMatmul_local_UB, 0, 8, 512, + 0, 1504) + with tik_instance.else_scope(): + tik_instance.data_move(input_2_local_L1, input_x2[core_n_idx * 262144 + core_n_idx * 2048], 0, 8, 128, + 1920, 0) + tik_instance.data_move(input_1_local_L1, input_x1[core_n_idx * 129024 + 2 * 4096], 0, 8, 256, 752, 0) + with tik_instance.for_range(0, 8) as cc10: + tik_instance.load2dv1(input_2_local_L1_local_L0B[cc10 * 2048], input_2_local_L1[cc10 * 256], 0, 8, + 8, 0, True) + with tik_instance.for_range(0, 16) as cc101: + tik_instance.load2dv1(input_1_local_L1_local_L0A[cc101 * 2048], input_1_local_L1[cc101 * 256], 0, 8, + 16, 0, False) + + tik_instance.mmad(resMatmul_local_UB_local_L0C, input_1_local_L1_local_L0A, input_2_local_L1_local_L0B, + 256, 128, 128, 0) + tik_instance.data_move(resMatmul_local_UB, resMatmul_local_UB_local_L0C, 0, 1, 128, 0, 0) + tik_instance.vmuls(64, resMatmul_local_UB, resMatmul_local_UB, matrix_max_scalar, 255, 1, 1, 8, 8) + tik_instance.vmuls(64, resMatmul_local_UB[255 * 64], resMatmul_local_UB[255 * 64], matrix_max_scalar, + 255, 1, 1, 8, 8) + tik_instance.vmuls(64, resMatmul_local_UB[510 * 64], resMatmul_local_UB[510 * 64], matrix_max_scalar, 2, + 1, 1, 8, 8) + + tik_instance.data_move(resMatmul[core_n_idx * 129024 + 2 * 4096], resMatmul_local_UB, 0, 8, 512, 0, + 1504) + + tik_instance.data_move(input_2_local_L11, input_x2[core_n_idx * 262144 + core_n_idx * 2048], 0, 8, 128, + 1920, 0) + tik_instance.data_move(input_1_local_L11, input_x1[core_n_idx * 129024 + 12288], 0, 8, 240, 768, 0) + + with tik_instance.for_range(0, 8) as cc102: + tik_instance.load2dv1(input_2_local_L1_local_L0B1[cc102 * 2048], input_2_local_L11[cc102 * 256], 0, + 8, 8, 0, True) + with tik_instance.for_range(0, 16) as cc103: + tik_instance.load2dv1(input_1_local_L1_local_L0A[cc103 * 2048], input_1_local_L11[cc103 * 256], 0, + 8, 15, 0, False) + + tik_instance.mmad(resMatmul_local_UB_local_L0C1, input_1_local_L1_local_L0A, + input_2_local_L1_local_L0B1, 240, 128, 128, 0) + tik_instance.data_move(resMatmul_local_UB1, resMatmul_local_UB_local_L0C1, 0, 1, 120, 0, 0) + + tik_instance.vmuls(64, resMatmul_local_UB1, resMatmul_local_UB1, matrix_max_scalar, 255, 1, 1, 8, 8) + tik_instance.vmuls(64, resMatmul_local_UB1[255 * 64], resMatmul_local_UB1[255 * 64], matrix_max_scalar, + 225, 1, 1, 8, 8) + + tik_instance.data_move(resMatmul[core_n_idx * 129024 + 12288], resMatmul_local_UB1, 0, 8, 480, 0, 1536) + + tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x1, input_x2, input_x3], outputs=[resMatmul]) + return tik_instance diff --git a/mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py b/mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py new file mode 100644 index 00000000000..9a30da37847 --- /dev/null +++ b/mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py @@ -0,0 +1,526 @@ +# -*- coding:utf-8 -*- +""" +copyright 2020 Huawei Technologies Co., Ltd + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License == distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +matmul +""" +from __future__ import absolute_import +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType +import te.platform.cce_params as cce +from te import tik +from topi.cce import util + +# General limitation of the size for input shape: 2**31 +SHAPE_SIZE_LIMIT = 2147483648 +NoneType = type(None) + +matmul_cube_fracz_left_cast_op_info = TBERegOp("CusMatMulCubeFraczLeftCast") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("matmulcubefraczleftcast.so") \ + .compute_cost(10) \ + .kernel_name("CusMatMulCubeFraczLeftCast") \ + .partial_flag(True) \ + .input(0, "x1", False, "required", "all") \ + .input(1, "x2", False, "required", "all") \ + .input(2, "x3", False, "optional", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.F32_FracZ, DataType.F16_Default, DataType.F16_FracZ) \ + .get_op_info() + + +# pylint: disable=locally-disabled,too-many-arguments,too-many-branches, too-many-statements, too-many-locals, +def _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b): + """ + Check the given input if legal + + Parameters: + shape_a: list or tuple + Shape of the first tensor a with rank > 1 + shape_b: list or tuple + Shape of the second tensor b with the same type with a, + and shape_a, shape_b must be 2 dims + shape_bias: list or tuple + Shape of bias, only support the input data format with ND +src_dtype: str + The data type of input, support "float32", "float16" + trans_a: bool + If True, shape_a == transposed before multiplication + trans_b: bool + If True, shape_b == transposed before multiplication + + Returns None + """ + shape_len = len(shape_a) + src_dtype = src_dtype.lower() + k_block_size = cce.BLOCK_REDUCE + + check_list = ("float16") + + if src_dtype not in check_list: + raise RuntimeError("matmul_cce only support %s while src_dtype == %s" + % (",".join(check_list), src_dtype)) + if shape_len != len(shape_b): + raise RuntimeError("length of a and b are not equal") + + if shape_len != 2: + raise RuntimeError( + "length of shape must be 2, more than 2 dimensions should use batch_matmul now!") + + is_gevm = True if shape_a[-2] == 1 or shape_a[-1] == 1 else False + is_gemv = True if shape_b[-2] == 1 or shape_b[-1] == 1 else False + + if trans_a: + m_shape = shape_a[shape_len - 1] + km_shape = shape_a[shape_len - 2] + else: + m_shape = shape_a[shape_len - 2] + km_shape = shape_a[shape_len - 1] + + if trans_b: + kn_shape = shape_b[shape_len - 1] + n_shape = shape_b[shape_len - 2] + else: + kn_shape = shape_b[shape_len - 2] + n_shape = shape_b[shape_len - 1] + + if m_shape == 1: + if n_shape == 1: + raise RuntimeError("input shape M and N can't both be 1") + + if km_shape != kn_shape: + print(km_shape, kn_shape) + raise RuntimeError("reduce axis not same") + + if m_shape % cce.BLOCK_IN != 0 and m_shape != 1: + raise RuntimeError( + "input shape M should be 1 or multiple of %d" % cce.BLOCK_IN) + + if m_shape != 1: + if n_shape == 1: + if km_shape % (cce.BLOCK_IN * cce.BLOCK_IN) != 0: + raise RuntimeError("input shape K1 should be multiple of %d" + % (cce.BLOCK_IN * cce.BLOCK_IN)) + elif km_shape % k_block_size != 0: + raise RuntimeError( + "input shape K1 should be multiple of %d" % cce.BLOCK_IN) + else: + if km_shape % (cce.BLOCK_IN * cce.BLOCK_IN) != 0: + raise RuntimeError("input shape K1 should be multiple of %d" + % (cce.BLOCK_IN * cce.BLOCK_IN)) + + if n_shape % cce.BLOCK_IN != 0 and n_shape != 1: + raise RuntimeError("input shape N should be 1 or multiple of %d" % cce.BLOCK_IN) + + if len(shape_bias): + if len(shape_bias) == 1: + if is_gevm or is_gemv: + if shape_bias[0] != m_shape * n_shape: + raise RuntimeError("broadcast case shape bias for gemv must be equal m*n") + else: + if shape_bias[0] != n_shape: + raise RuntimeError("broadcast bias shape must be equal to shape n") + elif len(shape_bias) == shape_len: + if [i for i in shape_bias[-2:]] != [m_shape, n_shape]: + raise RuntimeError("non broadcast bias shape must be same as output shape") + else: + raise RuntimeError("unsupport input shape now for batch bias case") + + +def _get_bias(shape_bias): + """_get_bias""" + bias_length = shape_bias[0] + if bias_length % 16 == 0: + return shape_bias + else: + bias_length = (bias_length // 16) * 16 + 16 + shape_bias = [] + shape_bias.append(bias_length) + return shape_bias + + +def _get_input_shape(shape_x): + """_get_input_shape""" + dim_a = shape_x[0] + dim_b = shape_x[1] + res = [] + if dim_a % 16 != 0: + dim_a = (dim_a // 16) * 16 + 16 + res.append(dim_a) + else: + res.append(dim_a) + + if dim_b % 16 != 0: + dim_b = (dim_b // 16) * 16 + 16 + res.append(dim_b) + else: + res.append(dim_b) + return res + + +def check_supported(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, kernel_name="matmulcube"): + """check_supported""" + shape_a = input_x1.get("shape") + shape_b = input_x2.get("shape") + print("shape_a: ", shape_a) + print("shape_b: ", shape_b) + src_dtype = input_x1.get("dtype") + util.check_kernel_name(kernel_name) + util.check_shape_rule(shape_a) + util.check_shape_rule(shape_b) + util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT) + util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT) + try: + trans_a_f = bool(1 - trans_a) + if src_dtype == "float32" or src_dtype == "int32": + if len(shape_a) != 2 and len(shape_b) != 2: + return False + if trans_b: + if shape_b[0] == 1: + return False + else: + if shape_b[1] == 1: + return False + if trans_a: + if trans_b: + if shape_a[0] != shape_b[1]: + return False + elif shape_a[0] != shape_b[0]: + return False + elif trans_b: + if shape_a[1] != shape_b[1]: + return False + elif shape_a[1] != shape_b[0]: + return False + + if trans_a_f and trans_b and shape_b[1] == 1: + return False + + if src_dtype == "float16": + if len(shape_a) != 2 and len(shape_b) != 2: + return False + + if trans_a: + m_shape = shape_a[1] + k_shape = shape_a[0] + else: + m_shape = shape_a[0] + k_shape = shape_a[1] + + if trans_b: + n_shape = shape_b[0] + k_b_shape = shape_b[1] + else: + n_shape = shape_b[1] + k_b_shape = shape_b[0] + + if k_shape != k_b_shape: + return False + + if m_shape == 1 or n_shape == 1: + if k_shape % 256 != 0: + return False + + except RuntimeError as e: + return False + + return True + + +# pylint: disable=locally-disabled,too-many-arguments, too-many-locals, too-many-statements +@op_info_register(matmul_cube_fracz_left_cast_op_info) +def CusMatMulCubeFraczLeftCast(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, + kernel_name="CusMatMulCubeFraczLeftCast"): + """ + calculating matrix multiplication with bias, C = A*B + bias, support input + data with fractal format. + + Parameters: + shape_a: list or tuple + Shape of the first tensor a with rank > 1 + shape_b: list or tuple + Shape of the second tensor b with the same type with a, + and shape_a, shape_b must be 2 dims + src_dtype: str + The data type of input, support "float32", "float16" + dst_dtype: str + The data type of output, support "float32", "float16" + trans_a: bool + If True, shape_a == transposed before multiplication + trans_b: bool + If True, shape_b == transposed before multiplication + is_fractal: bool + If True, the input data format of a and b must be fractal format + shape_bias: list or tuple + Shape of bias, only support the input data format with ND + + Returns + ------- + None + """ + shape_a = input_x1.get("ori_shape") + shape_b = input_x2.get("ori_shape") + print("============") + print(input_x1.get("format"), input_x2.get("format")) + print(shape_a, shape_b) + print("============") + if input_x2.get("format") == "FRACTAL_Z": + n, c, h, w = shape_b + c0 = 16 + c1 = c // c0 + if c1 == 0: + c1 = 1 + shape_b = [n, c1 * h * w * c0] + shape_a = [n, n] + + if input_x1.get("format") == "FRACTAL_Z": + n, c, h, w = shape_a + c0 = 16 + c1 = c // c0 + if c1 == 0: + c1 = 1 + shape_a = [n, c1 * h * w * c0] + shape_b = [c1 * h * w * c0, c1 * h * w * c0] + + if input_x2.get("format") == "FRACTAL_NZ": + shape_a = [shape_b[0], shape_b[0]] + shape_b = shape_b + + if input_x1.get("format") == "FRACTAL_NZ": + shape_a = shape_a + shape_b = [shape_a[1], shape_a[1]] + + shape_a = list(shape_a) + shape_b = list(shape_b) + + shape_a = _get_input_shape(shape_a) + shape_b = _get_input_shape(shape_b) + + util.check_kernel_name(kernel_name) + util.check_shape_rule(shape_a) + util.check_shape_rule(shape_b) + util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT) + util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT) + + shape_a = [shape_a[1], shape_a[0]] + trans_a = bool(1 - trans_a) + + shape_b = [shape_b[1], shape_b[0]] + trans_b = bool(1 - trans_b) + + shape_bias = () + if bias is not None and bool(bias): + shape_bias = bias.get("shape") + shape_bias = list(shape_bias) + shape_bias = _get_bias(shape_bias) + + src_dtype = input_x1.get("dtype").lower() + _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b) + + m_shape = shape_a[len(shape_a) - 2] + km_shape = shape_a[len(shape_a) - 1] + kn_shape = shape_b[len(shape_a) - 2] + n_shape = shape_b[len(shape_a) - 1] + + if src_dtype == "float16": + block_reduce = cce.BLOCK_REDUCE + + block_in = cce.BLOCK_IN + block_out = cce.BLOCK_OUT + + if trans_a and km_shape == 1: + block_in = cce.BLOCK_VECTOR + + if not trans_a and m_shape == 1: + block_in = cce.BLOCK_VECTOR + + if trans_b and kn_shape == 1: + block_out = cce.BLOCK_VECTOR + + if not trans_b and n_shape == 1: + block_out = cce.BLOCK_VECTOR + + if trans_a: + shape_a_temp = (m_shape // block_reduce, km_shape // block_in, block_reduce, block_in) + else: + shape_a_temp = (m_shape // block_in, km_shape // block_reduce, block_in, block_reduce) + + if trans_b: + shape_b_temp = (kn_shape // block_out, n_shape // block_reduce, block_reduce, block_out) + else: + shape_b_temp = (kn_shape // block_reduce, n_shape // block_out, block_out, block_reduce) + shape_a_temp = (shape_a_temp[0], shape_a_temp[1], shape_a_temp[2], shape_a_temp[3]) + shape_b_temp = (shape_b_temp[0], shape_b_temp[1], shape_b_temp[2], shape_b_temp[3]) + + if util.get_product_version() == util.VERSION_MINI: + tik_instance = tik.Tik(tik.Dprofile("v100", "mini")) + else: + tik_instance = tik.Tik(tik.Dprofile("v100", "cloud")) + input_x1 = tik_instance.Tensor(input_x1.get("dtype"), shape_a_temp, name="left_matrix", scope=tik.scope_gm) + input_x2 = tik_instance.Tensor(input_x2.get("dtype"), shape_b_temp, name="right_matrix", scope=tik.scope_gm) + res_matmul = tik_instance.Tensor(output_y.get("dtype"), output_y.get("shape"), name="output", scope=tik.scope_gm) + DIAG_SIZE = 128 + mo_tile, ko_tile, no_tile, diag_opt = get_cus_tile_info(input_x1, input_x2, DIAG_SIZE) + cus_cube_matmul_cast(tik_instance, input_x1, trans_a, input_x2, trans_b, res_matmul, + mo_tile=mo_tile, ko_tile=ko_tile, no_tile=no_tile, + diag_opt=diag_opt, diag_size=DIAG_SIZE) + tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x1, input_x2], outputs=[res_matmul]) + return tik_instance + + +def get_cus_tile_info(input_x1, input_x2, diag_size): + """get_cus_tile_info""" + tile_map = { + ((32, 32, 16, 16), (128, 32, 16, 16)): (8, 8, 16), + ((8, 8, 16, 16), (72, 8, 16, 16)): (8, 8, 4), + ((32, 32, 16, 16), (288, 32, 16, 16)): (8, 8, 12), + ((128, 128, 16, 16), (32, 128, 16, 16)): (8, 8, 16), + ((16, 16, 16, 16), (144, 16, 16, 16)): (8, 8, 9), + ((64, 64, 16, 16), (16, 64, 16, 16)): (8, 8, 4), + ((16, 16, 16, 16), (64, 16, 16, 16)): (8, 8, 4), + ((32, 32, 16, 16), (8, 32, 16, 16)): (8, 8, 1), + ((128, 128, 16, 16), (64, 128, 16, 16)): (8, 8, 16), + ((16, 16, 16, 16), (4, 16, 16, 16)): (8, 8, 1), + ((16, 16, 16, 16), (32, 16, 16, 16)): (8, 8, 2), + ((64, 64, 16, 16), (32, 64, 16, 16)): (8, 8, 8), + ((32, 32, 16, 16), (64, 32, 16, 16)): (8, 8, 8), + ((32, 32, 16, 16), (16, 32, 16, 16)): (8, 8, 2), + ((8, 8, 16, 16), (32, 8, 16, 16)): (8, 8, 1), + ((8, 8, 16, 16), (16, 8, 16, 16)): (4, 8, 1), + ((4, 4, 16, 16), (16, 4, 16, 16)): (2, 4, 1), + ((4, 4, 16, 16), (4, 4, 16, 16)): (1, 4, 1), + ((4, 4, 16, 16), (36, 4, 16, 16)): (2, 4, 3), + ((4, 4, 16, 16), (49, 4, 16, 16)): (1, 4, 7) + } + shape_info = (tuple(input_x1.shape), tuple(input_x2.shape)) + diag_opt = False + if input_x1.shape[0] * input_x1.shape[3] > diag_size: + diag_opt = True + if shape_info not in tile_map: + raise ValueError("shape %s is not supported" % str(shape_info)) + mo_tile, ko_tile, no_tile = tile_map[shape_info] + return mo_tile, ko_tile, no_tile, diag_opt + + +def cus_cube_matmul_cast(tik_instance, input_x1, trans_a, input_x2, trans_b, + res, mo_tile, ko_tile, no_tile, diag_opt=False, diag_size=128): + """cus_cube_matmul_cast""" + ko, mo, _, _ = input_x1.shape + no, ko, _, _ = input_x2.shape + c0 = input_x1.shape[-1] + diag_outer = diag_size // c0 + maxblocknum = 32 + fp32_size = 4 + fp16_size = 2 + blocksize = 32 + vectorfp32_size = 64 + if [input_x1.shape[-1], input_x1.shape[-2], input_x2.shape[-1], input_x2.shape[-2]] != [c0, c0, c0, c0]: + raise ValueError("shape of input_x1 or input_x2 is not supported!") + if not trans_a or not trans_b: + raise ValueError("only trans_a=False and trans_b=False be supported!") + + core_m_num = mo // mo_tile + loop_n_num = no // no_tile + if loop_n_num * core_m_num <= maxblocknum: + core_n_num = loop_n_num + else: + core_n_num = maxblocknum // core_m_num + if core_n_num > 0 and loop_n_num % core_n_num == 0: + loop_n_num = loop_n_num // core_n_num + else: + raise ValueError("Does not support this scenario!") + block_num = core_m_num * core_n_num + + loop_k_num = ko // ko_tile + if diag_opt: + loop_k_num = diag_outer // ko_tile + # double buffer: + thread_num_k = 2 + loop_k_num *= thread_num_k + ko_tile_inner = ko_tile // thread_num_k + with tik_instance.for_range(0, block_num, block_num=block_num) as block_idx: + core_m = block_idx // core_n_num + core_n = block_idx % core_n_num + with tik_instance.for_range(0, loop_n_num) as cc_n: + res_L0C = tik_instance.Tensor("float32", [no_tile, mo_tile, c0, c0], + name="resMatmul_L0C", scope=tik.scope_cc) + with tik_instance.for_range(0, loop_k_num, thread_num=thread_num_k) as thread_idx_k: + # input_x2 -> input_x2_ub -(fp322fp16)-> input_x2_cast_ub -> input_x2_L1 + input_x2_ub = tik_instance.Tensor("float32", [no_tile, ko_tile_inner, c0, c0], name="input_x2_ub", + scope=tik.scope_ubuf) + if diag_opt: + k_idx = core_m * mo_tile + thread_idx_k * ko_tile_inner + else: + k_idx = thread_idx_k * ko_tile_inner + tik_instance.data_move(input_x2_ub, + input_x2[(core_n * loop_n_num + cc_n) * no_tile, + k_idx, 0, 0], + 0, no_tile, ko_tile_inner * c0 * c0 * fp32_size // blocksize, + (ko - ko_tile_inner) * c0 * c0 * fp32_size // blocksize, 0) + input_x2_cast_ub = tik_instance.Tensor("float16", [no_tile, ko_tile_inner, c0, c0], + name="input_x2_cast_ub", scope=tik.scope_ubuf) + repeate_num = no_tile * ko_tile_inner * c0 * c0 // vectorfp32_size + repeate_times_max = 255 + count = 0 + while repeate_num > repeate_times_max: + tik_instance.vconv(vectorfp32_size, 'none', + input_x2_cast_ub[count * repeate_times_max * vectorfp32_size], + input_x2_ub[count * repeate_times_max * vectorfp32_size], + repeate_times_max, + 1, 1, 4, 8) + repeate_num -= repeate_times_max + count += 1 + tik_instance.vconv(vectorfp32_size, 'none', + input_x2_cast_ub[count * repeate_times_max * vectorfp32_size], + input_x2_ub[count * repeate_times_max * vectorfp32_size], repeate_num, + 1, 1, 4, 8) + input_x2_L1 = tik_instance.Tensor("float16", [no_tile, ko_tile_inner, c0, c0], + name="input_x2_L1", scope=tik.scope_cbuf) + tik_instance.data_move(input_x2_L1, input_x2_cast_ub, 0, 1, + no_tile * ko_tile_inner * c0 * c0 * fp16_size // blocksize, 0, 0) + # input_x1 -> input_x1_L1 + input_x1_L1 = tik_instance.Tensor(input_x1.dtype, [ko_tile_inner, mo_tile, c0, c0], + name="input_x1_L1", scope=tik.scope_cbuf) + tik_instance.data_move(input_x1_L1, + input_x1[k_idx, + core_m * mo_tile, 0, 0], + 0, ko_tile_inner, mo_tile * c0 * c0 * fp16_size // blocksize, + (mo - mo_tile) * c0 * c0 * fp16_size // blocksize, 0) + # input_x2_L1 -> input_x2_L0B + input_x2_L0B = tik_instance.Tensor("float16", [ko_tile_inner, no_tile, c0, c0], + name="input_x2_L0B", scope=tik.scope_cb) + with tik_instance.for_range(0, ko_tile_inner) as cc2: + tik_instance.load2dv1(input_x2_L0B[cc2, 0, 0, 0], input_x2_L1[0, cc2, 0, 0], 0, no_tile, + ko_tile_inner, + 0, True) + # input_x1_L1 -> input_x1_L0A + input_x1_L0A = tik_instance.Tensor(input_x1.dtype, [mo_tile, ko_tile_inner, c0, c0], + name="input_x1_L0A", scope=tik.scope_ca) + with tik_instance.for_range(0, mo_tile) as cc1: + tik_instance.load2dv1(input_x1_L0A[cc1, 0, 0, 0], input_x1_L1[0, cc1, 0, 0], 0, ko_tile_inner, + mo_tile, 0, False) + with tik_instance.if_scope(thread_idx_k == 0): + tik_instance.mmad(res_L0C, input_x1_L0A, input_x2_L0B, mo_tile * c0, + ko_tile_inner * c0, no_tile * c0, 0) + with tik_instance.else_scope(): + tik_instance.mmad(res_L0C, input_x1_L0A, input_x2_L0B, mo_tile * c0, + ko_tile_inner * c0, no_tile * c0, 1) + res_ub = tik_instance.Tensor(input_x1.dtype, [no_tile, mo_tile, c0, c0], + name="resMatmul_ub", scope=tik.scope_ubuf) + tik_instance.data_move(res_ub, res_L0C, 0, 1, no_tile * mo_tile, 0, 0, 1) + tik_instance.data_move(res[(core_n * loop_n_num + cc_n) * no_tile, core_m * mo_tile, 0, 0], + res_ub, 0, no_tile, + mo_tile * c0 * c0 * fp16_size // blocksize, 0, + (mo - mo_tile) * c0 * c0 * fp16_size // blocksize) diff --git a/mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py b/mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py new file mode 100644 index 00000000000..79fab2c3cd2 --- /dev/null +++ b/mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py @@ -0,0 +1,247 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +""" +copyright 2020 Huawei Technologies Co., Ltd + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License == distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +matmul +""" +from __future__ import absolute_import + +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType +from te import tik +from topi.cce import util + +# General limitation of the size for input shape: 2**31 +SHAPE_SIZE_LIMIT = 2147483648 +NoneType = type(None) + +cus_matmul_cube_fracz_right_mul_op_info = TBERegOp("CusMatMulCubeFraczRightMul") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("matmulcubefraczrightmul.so") \ + .compute_cost(10) \ + .kernel_name("CusMatMulCubeFraczRightMul") \ + .partial_flag(True) \ + .input(0, "x1", False, "required", "all") \ + .input(1, "x2", False, "required", "all") \ + .input(2, "x3", False, "required", "all") \ + .input(3, "x4", False, "optional", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_FracZ, DataType.F16_Default, DataType.F32_Default, DataType.F16_Default, + DataType.F32_FracZ) \ + .get_op_info() + + +@op_info_register(cus_matmul_cube_fracz_right_mul_op_info) +def CusMatMulCubeFraczRightMul(input_x1, input_x2, input_x3, bias=None, output_y={}, trans_a=False, trans_b=False, + kernel_name="matmulcube"): + """CusMatMulCubeFraczRightMul""" + if util.get_product_version() == util.VERSION_MINI: + tik_instance = tik.Tik(tik.Dprofile("v100", "mini")) + else: + tik_instance = tik.Tik(tik.Dprofile("v100", "cloud")) + + input_x1_shape = input_x1.get("shape") + input_x1_dtype = input_x1.get("dtype").lower() + input_x2_shape = input_x2.get("shape") + input_x2_dtype = input_x2.get("dtype").lower() + input_x3_shape = input_x3.get("shape") + input_x3_dtype = input_x3.get("dtype").lower() + output_shape = output_y.get("shape") + Supported = [((72, 8, 16, 16), "float16", (72, 72, 16, 16), "float16", (1,), "float32"), + ((32, 8, 16, 16), "float16", (32, 32, 16, 16), "float16", (1,), "float32"), + ((8, 32, 16, 16), "float16", (8, 8, 16, 16), "float16", (1,), "float32"), + ((4, 4, 16, 16), "float16", (4, 4, 16, 16), "float16", (1,), "float32"), + ((4, 16, 16, 16), 'float16', (4, 4, 16, 16), 'float16', (1,), 'float32'), + ((49, 4, 16, 16), 'float16', (49, 49, 16, 16), 'float16', (1,), 'float32'), + ((36, 4, 16, 16), 'float16', (36, 36, 16, 16), 'float16', (1,), 'float32'), + ((64, 16, 16, 16), 'float16', (64, 64, 16, 16), 'float16', (1,), 'float32'), + ((32, 64, 16, 16), 'float16', (32, 32, 16, 16), 'float16', (1,), 'float32'), + ((32, 16, 16, 16), 'float16', (32, 32, 16, 16), 'float16', (1,), 'float32'), + ((16, 32, 16, 16), 'float16', (16, 16, 16, 16), 'float16', (1,), 'float32'), + ((16, 8, 16, 16), 'float16', (16, 16, 16, 16), 'float16', (1,), 'float32'), + ((16, 4, 16, 16), 'float16', (16, 16, 16, 16), 'float16', (1,), 'float32'), + ((288, 32, 16, 16), 'float16', (288, 288, 16, 16), 'float16', (1,), 'float32'), + ((144, 16, 16, 16), 'float16', (144, 144, 16, 16), 'float16', (1,), 'float32'), + ((128, 32, 16, 16), 'float16', (128, 128, 16, 16), 'float16', (1,), 'float32'), + ((64, 128, 16, 16), 'float16', (64, 64, 16, 16), 'float16', (1,), 'float32'), + ((32, 128, 16, 16), 'float16', (32, 32, 16, 16), 'float16', (1,), 'float32'), + ((64, 32, 16, 16), 'float16', (64, 64, 16, 16), 'float16', (1,), 'float32'), + ((16, 64, 16, 16), 'float16', (16, 16, 16, 16), 'float16', (1,), 'float32')] + input_shape = ( + tuple(input_x1_shape), input_x1_dtype, tuple(input_x2_shape), input_x2_dtype, tuple(input_x3_shape), input_x3_dtype) + if input_shape not in Supported: + raise RuntimeError("input_shape %s is not supported" % str(input_shape)) + + input_x1 = tik_instance.Tensor("float16", input_x1_shape, name="left_matrix", scope=tik.scope_gm) + input_x2 = tik_instance.Tensor("float16", input_x2_shape, name="right_matrix", scope=tik.scope_gm) + input_x3 = tik_instance.Tensor("float32", input_x3_shape, name="matrix_max", scope=tik.scope_gm) + resMatmul = tik_instance.Tensor("float32", output_shape, name="output", scope=tik.scope_gm) + cus_cube_matmul_right_mul(tik_instance, input_x1, input_x2, input_x3, resMatmul) + tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x1, input_x2, input_x3], outputs=[resMatmul]) + return tik_instance + + +def cus_cube_matmul_right_mul(tik_instance, input_x1, input_x2, input_x3, + res): + """cus_cube_matmul_right_mul""" + diag_size = 128 + ko, mo, _, _ = input_x1.shape + no, ko, _, _ = input_x2.shape + c0 = input_x1.shape[-1] + diag_outer = diag_size // c0 + if [input_x1.shape[-1], input_x1.shape[-2], input_x2.shape[-1], input_x2.shape[-2]] != [c0, c0, c0, c0]: + raise ValueError("shape of input_x1 or input_x2 is not supported!") + + def get_cus_tile_info(input_x1, input_x2, input_x3): + """get_cus_tile_info""" + input_shape = (tuple(input_x1.shape), input_x1.dtype, tuple(input_x2.shape), input_x2.dtype, + tuple(input_x3.shape), input_x3.dtype) + tile_map = { + # no diag opt: + ((8, 32, 16, 16), "float16", (8, 8, 16, 16), "float16", (1,), "float32"): (4, 8, 2, 8, 4), + ((4, 4, 16, 16), "float16", (4, 4, 16, 16), "float16", (1,), "float32"): (1, 4, 1, 4, 4), + ((4, 16, 16, 16), 'float16', (4, 4, 16, 16), 'float16', (1,), 'float32'): (1, 4, 2, 16, 2), + ((49, 4, 16, 16), 'float16', (49, 49, 16, 16), 'float16', (1,), 'float32'): (1, 7, 7, 4, 7), + ((36, 4, 16, 16), 'float16', (36, 36, 16, 16), 'float16', (1,), 'float32'): (2, 6, 3, 2, 12), + # diag opt: + ((288, 32, 16, 16), 'float16', (288, 288, 16, 16), 'float16', (1,), 'float32'): (16, 8, 8, 2, 12), + } + maxblocknum = 32 + diag_opt = False + if input_x2.shape[0] * input_x2.shape[3] > diag_size and input_x2.shape[0] % diag_outer == 0: + diag_opt = True + if input_shape in tile_map: + mo_tile_, ko_tile_, no_tile_, core_m_num_, core_n_num_ = tile_map[input_shape] + elif diag_opt: + ko_tile_ = diag_outer + no_tile_ = ko_tile_ + core_n_num_ = no // no_tile_ + core_m_num_max = maxblocknum // core_n_num_ + mo_tile_ = -1 + core_m_num_ = -1 + for i in range(core_m_num_max, 0, -1): + if mo % i == 0: + core_m_num_ = i + mo_tile_ = mo // i + break + if mo_tile_ == -1: + raise ValueError("no valid tile be found!") + while mo_tile_ > 16: + mo_tile_ = mo_tile_ // 2 + else: + raise ValueError("please add tile config to the tile_map") + print("shape: %s, tile: %s" % (input_shape, str((mo_tile_, ko_tile_, no_tile_, core_m_num_, core_n_num_, + diag_opt)))) + return mo_tile_, ko_tile_, no_tile_, core_m_num_, core_n_num_, diag_opt + + mo_tile, ko_tile, no_tile, core_m_num, core_n_num, diag_opt = get_cus_tile_info(input_x1, input_x2, input_x3) + fp32_size = 4 + fp16_size = 2 + blocksize = 32 + vectorfp32_size = 64 + loop_n_num_total = no // no_tile + loop_m_num_total = mo // mo_tile + if loop_n_num_total % core_n_num != 0 or loop_m_num_total % core_m_num != 0: + raise ValueError("Does not support this scenario!") + loop_n_num = loop_n_num_total // core_n_num + loop_m_num = loop_m_num_total // core_m_num + block_num = core_n_num * core_m_num + loop_k_num = ko // ko_tile + if diag_opt: + loop_k_num = diag_outer // ko_tile + # double buffer: + thread_num_k = 2 + if ko_tile % 2 == 0: + loop_k_num *= thread_num_k + ko_tile_inner = ko_tile // thread_num_k + else: + ko_tile_inner = ko_tile + ko_tile *= thread_num_k + with tik_instance.for_range(0, block_num, block_num=block_num) as block_idx: + core_m = block_idx // core_n_num + core_n = block_idx % core_n_num + with tik_instance.for_range(0, loop_m_num) as cc_m: + with tik_instance.for_range(0, loop_n_num) as cc_n: + res_L0C = tik_instance.Tensor("float32", [no_tile, mo_tile, c0, c0], + name="resMatmul_L0C", scope=tik.scope_cc) + with tik_instance.for_range(0, loop_k_num, thread_num=thread_num_k) as thread_idx_k: + if diag_opt: + k_idx = (core_n * loop_n_num + cc_n) * no_tile + thread_idx_k * ko_tile_inner + else: + k_idx = thread_idx_k * ko_tile_inner + # input_x1 -> input_x1_L1 + input_x1_L1 = tik_instance.Tensor(input_x1.dtype, [ko_tile_inner, mo_tile, c0, c0], + name="input_x1_L1", scope=tik.scope_cbuf) + tik_instance.data_move(input_x1_L1, + input_x1[k_idx, + (core_m * loop_m_num + cc_m) * mo_tile, 0, 0], + 0, ko_tile_inner, mo_tile * c0 * c0 * fp16_size // blocksize, + (mo - mo_tile) * c0 * c0 * fp16_size // blocksize, 0) + # input_x2 -> input_x2_L1 + input_x2_L1 = tik_instance.Tensor("float16", [no_tile, ko_tile_inner, c0, c0], + name="input_x2_L1", scope=tik.scope_cbuf) + tik_instance.data_move(input_x2_L1, + input_x2[(core_n * loop_n_num + cc_n) * no_tile, + k_idx, 0, 0], + 0, no_tile, ko_tile_inner * c0 * c0 * fp16_size // blocksize, + (ko - ko_tile_inner) * c0 * c0 * fp16_size // blocksize, 0) + # input_x1_L1 -> input_x1_L0A + input_x1_L0A = tik_instance.Tensor(input_x1.dtype, [mo_tile, ko_tile_inner, c0, c0], + name="input_x1_L0A", scope=tik.scope_ca) + with tik_instance.for_range(0, mo_tile) as cc1: + tik_instance.load2dv1(input_x1_L0A[cc1, 0, 0, 0], input_x1_L1[0, cc1, 0, 0], 0, ko_tile_inner, + mo_tile, 0, False) + # input_x2_L1 -> input_x2_L0B + input_x2_L0B = tik_instance.Tensor("float16", [ko_tile_inner, no_tile, c0, c0], + name="input_x2_L0B", scope=tik.scope_cb) + with tik_instance.for_range(0, ko_tile_inner) as cc2: + tik_instance.load2dv1(input_x2_L0B[cc2, 0, 0, 0], input_x2_L1[0, cc2, 0, 0], 0, no_tile, + ko_tile_inner, + 0, True) + with tik_instance.if_scope(thread_idx_k == 0): + tik_instance.mmad(res_L0C, input_x1_L0A, input_x2_L0B, mo_tile * c0, + ko_tile_inner * c0, no_tile * c0, 0) + with tik_instance.else_scope(): + tik_instance.mmad(res_L0C, input_x1_L0A, input_x2_L0B, mo_tile * c0, + ko_tile_inner * c0, no_tile * c0, 1) + res_ub = tik_instance.Tensor("float32", [no_tile, mo_tile, c0, c0], + name="resMatmul_ub", scope=tik.scope_ubuf) + tik_instance.data_move(res_ub, res_L0C, 0, 1, no_tile * mo_tile, 0, 0) + + input_3_local_UB = tik_instance.Tensor("float32", (8,), scope=tik.scope_ubuf, name="input_3_local_UB") + tik_instance.data_move(input_3_local_UB, input_x3, 0, 1, 1, 0, 0) + matrix_max_scalar = tik_instance.Scalar("float32") + matrix_max_scalar.set_as(input_3_local_UB[0]) + repeate_num = no_tile * mo_tile * c0 * c0 // vectorfp32_size + repeate_times_max = 255 + count = 0 + while repeate_num > repeate_times_max: + tik_instance.vmuls(vectorfp32_size, + res_ub[count * repeate_times_max * vectorfp32_size], + res_ub[count * repeate_times_max * vectorfp32_size], + matrix_max_scalar, repeate_times_max, 1, 1, 8, 8) + repeate_num -= repeate_times_max + count += 1 + tik_instance.vmuls(vectorfp32_size, + res_ub[count * repeate_times_max * vectorfp32_size], + res_ub[count * repeate_times_max * vectorfp32_size], + matrix_max_scalar, repeate_num, 1, 1, 8, 8) + + tik_instance.data_move(res[(core_n * loop_n_num + cc_n) * no_tile, + (core_m * loop_m_num + cc_m) * mo_tile, 0, 0], + res_ub, 0, no_tile, + mo_tile * c0 * c0 * fp32_size // blocksize, 0, + (mo - mo_tile) * c0 * c0 * fp32_size // blocksize) diff --git a/mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py b/mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py new file mode 100644 index 00000000000..603ed287f6e --- /dev/null +++ b/mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py @@ -0,0 +1,397 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +""" +copyright 2020 Huawei Technologies Co., Ltd + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License == distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +matmul +""" +from __future__ import absolute_import +from impl.matmul_vector import matmul_vector_cce +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType +import te.lang.cce +import te.platform.cce_params as cce +from te import tvm +from topi import generic +from topi.cce import util + +# General limitation of the size for input shape: 2**31 +SHAPE_SIZE_LIMIT = 2147483648 +NoneType = type(None) + +matmul_cube_op_info = TBERegOp("CusMatMulCube") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("matmulcube.so") \ + .compute_cost(10) \ + .kernel_name("CusMatMulCube") \ + .partial_flag(True) \ + .attr("transpose_a", "required", "bool", "all") \ + .attr("transpose_b", "required", "bool", "all") \ + .input(0, "x1", False, "required", "all") \ + .input(1, "x2", False, "required", "all") \ + .input(2, "x3", False, "optional", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.F32_FracNZ) \ + .get_op_info() + + +# pylint: disable=locally-disabled,too-many-arguments,too-many-branches, too-many-statements, too-many-locals, +def _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b): + """ + Check the given input if legal + + Parameters: + shape_a: list or tuple + Shape of the first tensor a with rank > 1 + shape_b: list or tuple + Shape of the second tensor b with the same type with a, + and shape_a, shape_b must be 2 dims + shape_bias: list or tuple + Shape of bias, only support the input data format with ND + src_dtype: str + The data type of input, support "float32", "float16" + trans_a: bool + If True, shape_a == transposed before multiplication + trans_b: bool + If True, shape_b == transposed before multiplication + + Returns None + """ + shape_len = len(shape_a) + src_dtype = src_dtype.lower() + k_block_size = cce.BLOCK_REDUCE + + check_list = ("float16") + + if src_dtype not in check_list: + raise RuntimeError("matmul_cce only support %s while src_dtype == %s" + % (",".join(check_list), src_dtype)) + if shape_len != len(shape_b): + raise RuntimeError("length of a and b are not equal") + + if shape_len != 2: + raise RuntimeError( + "length of shape must be 2, more than 2 dimensions should use batch_matmul now!") + + is_gevm = True if shape_a[-2] == 1 or shape_a[-1] == 1 else False + is_gemv = True if shape_b[-2] == 1 or shape_b[-1] == 1 else False + + if trans_a: + m_shape = shape_a[shape_len - 1] + km_shape = shape_a[shape_len - 2] + else: + m_shape = shape_a[shape_len - 2] + km_shape = shape_a[shape_len - 1] + + if trans_b: + kn_shape = shape_b[shape_len - 1] + n_shape = shape_b[shape_len - 2] + else: + kn_shape = shape_b[shape_len - 2] + n_shape = shape_b[shape_len - 1] + + if m_shape == 1: + if n_shape == 1: + raise RuntimeError("input shape M and N can't both be 1") + + if km_shape != kn_shape: + raise RuntimeError("reduce axis not same") + + if m_shape % cce.BLOCK_IN != 0 and m_shape != 1: + raise RuntimeError( + "input shape M should be 1 or multiple of %d" % cce.BLOCK_IN) + + if m_shape != 1: + if n_shape == 1: + if km_shape % (cce.BLOCK_IN * cce.BLOCK_IN) != 0: + raise RuntimeError("input shape K1 should be multiple of %d" + % (cce.BLOCK_IN * cce.BLOCK_IN)) + elif km_shape % k_block_size != 0: + raise RuntimeError( + "input shape K1 should be multiple of %d" % cce.BLOCK_IN) + else: + if km_shape % (cce.BLOCK_IN * cce.BLOCK_IN) != 0: + raise RuntimeError("input shape K1 should be multiple of %d" + % (cce.BLOCK_IN * cce.BLOCK_IN)) + + if n_shape % cce.BLOCK_IN != 0 and n_shape != 1: + raise RuntimeError("input shape N should be 1 or multiple of %d" % cce.BLOCK_IN) + + if len(shape_bias): + if len(shape_bias) == 1: + if is_gevm or is_gemv: + if shape_bias[0] != m_shape * n_shape: + raise RuntimeError("broadcast case shape bias for gemv must be equal m*n") + else: + if shape_bias[0] != n_shape: + raise RuntimeError("broadcast bias shape must be equal to shape n") + elif len(shape_bias) == shape_len: + if [i for i in shape_bias[-2:]] != [m_shape, n_shape]: + raise RuntimeError("non broadcast bias shape must be same as output shape") + else: + raise RuntimeError("unsupport input shape now for batch bias case") + + +def _get_bias(shape_bias): + """_get_bias""" + bias_length = shape_bias[0] + if bias_length % 16 == 0: + return shape_bias + else: + bias_length = (bias_length // 16) * 16 + 16 + shape_bias = [] + shape_bias.append(bias_length) + return shape_bias + + +def _get_input_shape(shape_x): + """_get_input_shape""" + dim_a = shape_x[0] + dim_b = shape_x[1] + res = [] + if dim_a % 16 != 0: + dim_a = (dim_a // 16) * 16 + 16 + res.append(dim_a) + else: + res.append(dim_a) + + if dim_b % 16 != 0: + dim_b = (dim_b // 16) * 16 + 16 + res.append(dim_b) + else: + res.append(dim_b) + return res + + +def check_supported(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, kernel_name="matmulcube"): + """check_supported""" + shape_a = input_x1.get("shape") + shape_b = input_x2.get("shape") + print("shape_a: ", shape_a) + print("shape_b: ", shape_b) + src_dtype = input_x1.get("dtype") + util.check_kernel_name(kernel_name) + util.check_shape_rule(shape_a) + util.check_shape_rule(shape_b) + util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT) + util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT) + try: + trans_a_f = bool(1 - trans_a) + if src_dtype == "float32" or src_dtype == "int32": + if len(shape_a) != 2 and len(shape_b) != 2: + return False + if trans_b: + if shape_b[0] == 1: + return False + else: + if shape_b[1] == 1: + return False + if trans_a: + if trans_b: + if shape_a[0] != shape_b[1]: + return False + elif shape_a[0] != shape_b[0]: + return False + elif trans_b: + if shape_a[1] != shape_b[1]: + return False + elif shape_a[1] != shape_b[0]: + return False + + if trans_a_f and trans_b and shape_b[1] == 1: + return False + + if src_dtype == "float16": + if len(shape_a) != 2 and len(shape_b) != 2: + return False + + if trans_a: + m_shape = shape_a[1] + k_shape = shape_a[0] + else: + m_shape = shape_a[0] + k_shape = shape_a[1] + + if trans_b: + n_shape = shape_b[0] + k_b_shape = shape_b[1] + else: + n_shape = shape_b[1] + k_b_shape = shape_b[0] + + if k_shape != k_b_shape: + return False + + if m_shape == 1 or n_shape == 1: + if k_shape % 256 != 0: + return False + + except RuntimeError as e: + return False + + return True + + +# pylint: disable=locally-disabled,too-many-arguments, too-many-locals, too-many-statements +@op_info_register(matmul_cube_op_info) +def CusMatMulCube(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, kernel_name="matmulcube"): + """ + calculating matrix multiplication with bias, C = A*B + bias, support input + data with fractal format. + + Parameters: + shape_a: list or tuple + Shape of the first tensor a with rank > 1 + shape_b: list or tuple + Shape of the second tensor b with the same type with a, + and shape_a, shape_b must be 2 dims + src_dtype: str + The data type of input, support "float32", "float16" + dst_dtype: str + The data type of output, support "float32", "float16" + trans_a: bool + If True, shape_a == transposed before multiplication + trans_b: bool + If True, shape_b == transposed before multiplication + is_fractal: bool + If True, the input data format of a and b must be fractal format + shape_bias: list or tuple + Shape of bias, only support the input data format with ND + + Returns + ------- + None + """ + shape_a = input_x1.get("ori_shape") + shape_b = input_x2.get("ori_shape") + + if shape_a is not None: + if len(shape_a) < 2: + shape_a = input_x1.get("shape") + + if shape_b is not None: + if len(shape_b) < 2: + shape_b = input_x2.get("shape") + + shape_a = list(shape_a) + shape_b = list(shape_b) + + if input_x1.get("format") == "FRACTAL_NZ": + shape_a = _get_input_shape(shape_a) + shape_b = _get_input_shape(shape_b) + + util.check_kernel_name(kernel_name) + util.check_shape_rule(shape_a) + util.check_shape_rule(shape_b) + util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT) + util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT) + + if input_x1.get("format") == "FRACTAL_NZ": + shape_a = [shape_a[1], shape_a[0]] + trans_a = bool(1 - trans_a) + + if input_x2.get("format") == "FRACTAL_NZ": + shape_b = [shape_b[1], shape_b[0]] + trans_b = bool(1 - trans_b) + + shape_bias = () + if bias is not None and bool(bias): + shape_bias = bias.get("shape") + shape_bias = list(shape_bias) + shape_bias = _get_bias(shape_bias) + + src_dtype = input_x1.get("dtype").lower() + dst_dtype = output_y.get("dtype").lower() + if src_dtype == "float32" or src_dtype == "int32": + matmul_vector_cce(shape_a, shape_b, src_dtype, trans_a, trans_b, shape_bias, kernel_name) + return + _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b) + m_shape = shape_a[len(shape_a) - 2] + km_shape = shape_a[len(shape_a) - 1] + kn_shape = shape_b[len(shape_a) - 2] + n_shape = shape_b[len(shape_a) - 1] + + if src_dtype == "float16": + block_reduce = cce.BLOCK_REDUCE + + block_in = cce.BLOCK_IN + block_out = cce.BLOCK_OUT + + if trans_a and km_shape == 1: + block_in = cce.BLOCK_VECTOR + + if not trans_a and m_shape == 1: + block_in = cce.BLOCK_VECTOR + + if trans_b and kn_shape == 1: + block_out = cce.BLOCK_VECTOR + + if not trans_b and n_shape == 1: + block_out = cce.BLOCK_VECTOR + + if trans_a: + shape_a_temp = (m_shape // block_reduce, km_shape // block_in, block_reduce, block_in) + else: + shape_a_temp = (m_shape // block_in, km_shape // block_reduce, block_in, block_reduce) + + if trans_b: + shape_b_temp = (kn_shape // block_out, n_shape // block_reduce, block_reduce, block_out) + else: + shape_b_temp = (kn_shape // block_reduce, n_shape // block_out, block_out, block_reduce) + + if input_x1.get("format") == "FORMAT_FRACTAL_Z": + shape_a_temp = (shape_a_temp[0], shape_a_temp[1], shape_a_temp[2], shape_a_temp[3]) + format_a = "fractal" + elif input_x1.get("format") == "FRACTAL_NZ": + shape_a_temp = (shape_a_temp[0], shape_a_temp[1], shape_a_temp[2], shape_a_temp[3]) + format_a = "FRACTAL_NZ" + else: + shape_a_temp = (shape_a[len(shape_a) - 2], shape_a[len(shape_a) - 1]) + format_a = "ND" + + if input_x2.get("format") == "FORMAT_FRACTAL_Z": + shape_b_temp = (shape_b_temp[0], shape_b_temp[1], shape_b_temp[2], shape_b_temp[3]) + format_b = "fractal" + elif input_x2.get("format") == "FRACTAL_NZ": + shape_b_temp = (shape_b_temp[0], shape_b_temp[1], shape_b_temp[2], shape_b_temp[3]) + format_b = "FRACTAL_NZ" + else: + shape_b_temp = (shape_b[len(shape_b) - 2], shape_b[len(shape_b) - 1]) + format_b = "ND" + + tensor_bias = None + tensor_a = tvm.placeholder(shape_a_temp, name='tensor_a', + dtype=src_dtype) + tensor_b = tvm.placeholder(shape_b_temp, name='tensor_b', + dtype=src_dtype) + + if len(shape_bias) > 0: + tensor_bias = tvm.placeholder(shape_bias, name='tensor_bias', + dtype=dst_dtype) + result = te.lang.cce.matmul(tensor_a, tensor_b, trans_a, trans_b, format_a=format_a, + format_b=format_b, dst_dtype=dst_dtype, tensor_bias=tensor_bias) + + with tvm.target.cce(): + schedule = generic.auto_schedule(result) + + tensor_list = [tensor_a, tensor_b, result] + if len(shape_bias) > 0: + tensor_list = [tensor_a, tensor_b, tensor_bias, result] + + config = {"print_ir": False, + "name": kernel_name, + "tensor_list": tensor_list} + + te.lang.cce.cce_build_code(schedule, config) diff --git a/mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py b/mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py new file mode 100644 index 00000000000..0a3f41386b9 --- /dev/null +++ b/mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py @@ -0,0 +1,81 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""CusMatrixCombine""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType +from te import tik +from topi.cce import util + +cus_matrix_combine_op_info = TBERegOp("CusMatrixCombine") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("matrixcombine.so") \ + .compute_cost(10) \ + .kernel_name("CusMatrixCombine") \ + .partial_flag(True) \ + .input(0, "x1", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(cus_matrix_combine_op_info) +def CusMatrixCombine(input_x, output, kernel_name="matrix_combine"): + """CusMatrixCombine""" + input_x_shape = input_x.get("shape") + output_shape = output.get("shape") + split_dim = 128 + + if util.get_product_version() == util.VERSION_MINI: + tik_instance = tik.Tik(tik.Dprofile("v100", "mini")) + else: + tik_instance = tik.Tik(tik.Dprofile("v100", "cloud")) + + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + + blocks = 32 + matrix_dim = input_x_shape[0] * input_x_shape[1] + if input_x_shape[0] == 1 and input_x_shape[1] == 64: + tiling_dim = 2 + bs = 1 + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (tiling_dim, matrix_dim), name="input_x_ub", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[0, block_index * tiling_dim, 0], 0, 1, 16, 0, 0) + tik_instance.data_move(res[block_index * tiling_dim, 0], input_x_ub, 0, 1, 16, 0, 0) + else: + tiling_dim = 4 + bs = input_x_shape[0] + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (tiling_dim, matrix_dim), name="input_x_ub", + scope=tik.scope_ubuf) + zero = tik_instance.Scalar("float32") + zero.set_as(0.0) + with tik_instance.for_range(0, bs) as i: + repeat_real = tiling_dim * matrix_dim // 64 + if repeat_real <= 255: + tik_instance.vector_dup(64, input_x_ub, zero, repeat_real, 1, 8) + else: + repeat_1 = 255 + repeat_2 = repeat_real - 255 + tik_instance.vector_dup(64, input_x_ub, zero, repeat_1, 1, 8) + tik_instance.vector_dup(64, input_x_ub[255 * 64], zero, repeat_2, 1, 8) + with tik_instance.for_range(0, tiling_dim) as j: + tik_instance.data_move(input_x_ub[j, split_dim * i], input_x[i, block_index * tiling_dim + j, 0], 0, + 1, 16, 0, 0) + tik_instance.data_move(res[i * split_dim + block_index * tiling_dim, 0], input_x_ub, 0, 1, + tiling_dim * matrix_dim * 4 // 32, 0, 0) + tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x], outputs=[res]) + return tik_instance diff --git a/mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py b/mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py new file mode 100644 index 00000000000..141e2c1d51b --- /dev/null +++ b/mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py @@ -0,0 +1,289 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""CusTranspose02314""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType +from te import tik +from topi.cce import util + +cus_transpose02314_op_info = TBERegOp("CusTranspose02314") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("transpose02314.so") \ + .compute_cost(10) \ + .kernel_name("CusTranspose02314") \ + .partial_flag(True) \ + .input(0, "x1", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.F16_Default) \ + .get_op_info() + + +@op_info_register(cus_transpose02314_op_info) +def CusTranspose02314(input_x, output, kernel_name="transpose021354"): + """CusTranspose02314""" + input_x_shape = input_x.get("shape") + output_shape = output.get("shape") + perm = (0, 2, 3, 1, 4) + input_x_shape = tuple(input_x_shape) + support_shape = [(32, 128, 7, 7, 16), + (32, 32, 7, 7, 16), + (32, 32, 14, 14, 16), + (32, 64, 14, 14, 16), + (32, 16, 14, 14, 16), + (32, 16, 28, 28, 16), + (32, 32, 28, 28, 16), + (32, 8, 28, 28, 16), + (32, 8, 56, 56, 16), + (32, 16, 56, 56, 16), + (32, 4, 56, 56, 16), + (32, 4, 112, 112, 16)] + if input_x_shape not in support_shape: + raise RuntimeError("input_shape %s is not supported" % str(input_x_shape)) + + if util.get_product_version() == util.VERSION_MINI: + tik_instance = tik.Tik(tik.Dprofile("v100", "mini")) + else: + tik_instance = tik.Tik(tik.Dprofile("v100", "cloud")) + + input_x = tik_instance.Tensor("float16", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float16", output_shape, name="res", scope=tik.scope_gm) + + dtype = "float16" + if tuple(input_x_shape) == (32, 4, 112, 112, 16): + with tik_instance.for_range(0, 32, block_num=32) as block_idx: + with tik_instance.for_range(0, 14) as cc1_db: + with tik_instance.for_range(0, 2, thread_num=2) as db_idx: + input_1_local_UB = tik_instance.Tensor(dtype, [28672], name="input_1_local_UB", + scope=tik.scope_ubuf) + T_transpose_local_UB = tik_instance.Tensor(dtype, [28672], name="T_transpose_local_UB", + scope=tik.scope_ubuf) + zero = tik_instance.Scalar(dtype="float16", init_value=0) + tik_instance.data_move(input_1_local_UB, + input_x[block_idx * 802816 + cc1_db * 14336 + 7168 * db_idx], 0, 4, 448, + 12096, 0) + with tik_instance.for_range(0, 448) as cc7: + with tik_instance.for_range(0, 4) as cc8: + tik_instance.vadds(16, T_transpose_local_UB[cc7 * 64 + cc8 * 16], + input_1_local_UB[7168 * cc8 + cc7 * 16], zero, 1, 1, 1, 0, 0) + tik_instance.data_move(res[block_idx * 802816 + cc1_db * 57344 + 28672 * db_idx], + T_transpose_local_UB, 0, 1, 1792, 0, 0) + elif tuple(input_x_shape) == (32, 4, 56, 56, 16): + with tik_instance.for_range(0, 32, block_num=32) as block_idx: + zero = tik_instance.Scalar(dtype="float16", init_value=0) + with tik_instance.for_range(0, 3) as cc1_db: + with tik_instance.for_range(0, 2, thread_num=2) as db_idx: + input_1_local_UB = tik_instance.Tensor(dtype, [28672], name="input_1_local_UB", + scope=tik.scope_ubuf) + T_transpose_local_UB = tik_instance.Tensor(dtype, [28672], name="T_transpose_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_1_local_UB, + input_x[block_idx * 200704 + cc1_db * 14336 + 7168 * db_idx], 0, 4, 448, + 2688, 0) + with tik_instance.for_range(0, 448) as cc7: + with tik_instance.for_range(0, 4) as cc8: + tik_instance.vadds(16, T_transpose_local_UB[cc7 * 64 + cc8 * 16], + input_1_local_UB[7168 * cc8 + cc7 * 16], zero, 1, 1, 1, 0, 0) + tik_instance.data_move(res[block_idx * 200704 + cc1_db * 57344 + 28672 * db_idx], + T_transpose_local_UB, 0, 1, 1792, 0, 0) + + input_1_local_UB2 = tik_instance.Tensor(dtype, [28672], name="input_1_local_UB2", scope=tik.scope_ubuf) + T_transpose_local_UB2 = tik_instance.Tensor(dtype, [28672], name="T_transpose_local_UB2", + scope=tik.scope_ubuf) + tik_instance.data_move(input_1_local_UB2, input_x[block_idx * 200704 + 43008], 0, 4, 448, 2688, 0) + with tik_instance.for_range(0, 448) as cc72: + with tik_instance.for_range(0, 4) as cc82: + tik_instance.vadds(16, T_transpose_local_UB2[cc72 * 64 + cc82 * 16], + input_1_local_UB2[7168 * cc82 + cc72 * 16], zero, 1, 1, 1, 0, 0) + tik_instance.data_move(res[block_idx * 200704 + 172032], T_transpose_local_UB2, 0, 1, 1792, 0, 0) + elif tuple(input_x_shape) == (32, 16, 56, 56, 16): + with tik_instance.for_range(0, 32, block_num=32) as block_idx: + zero = tik_instance.Scalar(dtype="float16", init_value=0) + with tik_instance.for_range(0, 14) as cc1_db: + with tik_instance.for_range(0, 2, thread_num=2) as db_idx: + input_1_local_UB = tik_instance.Tensor(dtype, [28672], name="input_1_local_UB", + scope=tik.scope_ubuf) + T_transpose_local_UB = tik_instance.Tensor(dtype, [28672], name="T_transpose_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_1_local_UB, + input_x[block_idx * 802816 + cc1_db * 3584 + 1792 * db_idx], 0, 16, 112, + 3024, 0) + with tik_instance.for_range(0, 112) as cc7: + with tik_instance.for_range(0, 16) as cc8: + tik_instance.vadds(16, T_transpose_local_UB[cc7 * 256 + cc8 * 16], + input_1_local_UB[1792 * cc8 + cc7 * 16], zero, 1, 1, 1, 0, 0) + tik_instance.data_move(res[block_idx * 802816 + cc1_db * 57344 + 28672 * db_idx], + T_transpose_local_UB, 0, 1, 1792, 0, 0) + elif tuple(input_x_shape) == (32, 8, 56, 56, 16): + with tik_instance.for_range(0, 32, block_num=32) as block_idx: + zero = tik_instance.Scalar(dtype="float16", init_value=0) + with tik_instance.for_range(0, 7) as cc1_db: + with tik_instance.for_range(0, 2, thread_num=2) as db_idx: + input_1_local_UB = tik_instance.Tensor(dtype, [28672], name="input_1_local_UB", + scope=tik.scope_ubuf) + T_transpose_local_UB = tik_instance.Tensor(dtype, [28672], name="T_transpose_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_1_local_UB, + input_x[block_idx * 401408 + cc1_db * 7168 + 3584 * db_idx], 0, 8, 224, 2912, + 0) + with tik_instance.for_range(0, 224) as cc7: + with tik_instance.for_range(0, 16) as cc8: + tik_instance.vadds(16, T_transpose_local_UB[cc7 * 128 + cc8 * 16], + input_1_local_UB[3584 * cc8 + cc7 * 16], zero, 1, 1, 1, 0, 0) + tik_instance.data_move(res[block_idx * 401408 + cc1_db * 57344 + 28672 * db_idx], + T_transpose_local_UB, 0, 1, 1792, 0, 0) + elif tuple(input_x_shape) == (32, 8, 28, 28, 16): + with tik_instance.for_range(0, 32, block_num=32) as block_idx: + zero = tik_instance.Scalar(dtype="float16", init_value=0) + with tik_instance.for_range(0, 2) as cc1_db: + with tik_instance.for_range(0, 2, thread_num=2) as db_idx: + input_1_local_UB = tik_instance.Tensor(dtype, [25088], name="input_1_local_UB", + scope=tik.scope_ubuf) + T_transpose_local_UB = tik_instance.Tensor(dtype, [25088], name="T_transpose_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_1_local_UB, + input_x[block_idx * 100352 + cc1_db * 6272 + 3136 * db_idx], 0, 8, 196, 588, + 0) + with tik_instance.for_range(0, 196) as cc7: + with tik_instance.for_range(0, 8) as cc8: + tik_instance.vadds(16, T_transpose_local_UB[cc7 * 128 + cc8 * 16], + input_1_local_UB[3136 * cc8 + cc7 * 16], zero, 1, 1, 1, 0, 0) + tik_instance.data_move(res[block_idx * 100352 + cc1_db * 50176 + 25088 * db_idx], + T_transpose_local_UB, 0, 1, 1568, 0, 0) + elif tuple(input_x_shape) == (32, 32, 28, 28, 16): + with tik_instance.for_range(0, 32, block_num=32) as block_idx: + zero = tik_instance.Scalar(dtype="float16", init_value=0) + with tik_instance.for_range(0, 7) as cc1_db: + with tik_instance.for_range(0, 2, thread_num=2) as db_idx: + input_1_local_UB = tik_instance.Tensor(dtype, [28672], name="input_1_local_UB", + scope=tik.scope_ubuf) + T_transpose_local_UB = tik_instance.Tensor(dtype, [28672], name="T_transpose_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_1_local_UB, input_x[block_idx * 401408 + cc1_db * 1792 + 896 * db_idx], + 0, 32, 56, 728, 0) + with tik_instance.for_range(0, 56) as cc7: + with tik_instance.for_range(0, 32) as cc8: + tik_instance.vadds(16, T_transpose_local_UB[cc7 * 512 + cc8 * 16], + input_1_local_UB[896 * cc8 + cc7 * 16], zero, 1, 1, 1, 0, 0) + tik_instance.data_move(res[block_idx * 401408 + cc1_db * 57344 + 28672 * db_idx], + T_transpose_local_UB, 0, 1, 1792, 0, 0) + elif tuple(input_x_shape) == (32, 16, 28, 28, 16): + with tik_instance.for_range(0, 32, block_num=32) as block_idx: + zero = tik_instance.Scalar(dtype="float16", init_value=0) + with tik_instance.for_range(0, 3) as cc1_db: + with tik_instance.for_range(0, 2, thread_num=2) as db_idx: + input_1_local_UB = tik_instance.Tensor(dtype, [28672], name="input_1_local_UB", + scope=tik.scope_ubuf) + T_transpose_local_UB = tik_instance.Tensor(dtype, [28672], name="T_transpose_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_1_local_UB, + input_x[block_idx * 200704 + cc1_db * 3584 + 1792 * db_idx], 0, 16, 112, 672, + 0) + with tik_instance.for_range(0, 112) as cc7: + with tik_instance.for_range(0, 16) as cc8: + tik_instance.vadds(16, T_transpose_local_UB[cc7 * 256 + cc8 * 16], + input_1_local_UB[1792 * cc8 + cc7 * 16], zero, 1, 1, 1, 0, 0) + tik_instance.data_move(res[block_idx * 200704 + cc1_db * 57344 + 28672 * db_idx], + T_transpose_local_UB, 0, 1, 1792, 0, 0) + + input_1_local_UB2 = tik_instance.Tensor(dtype, [28672], name="input_1_local_UB2", scope=tik.scope_ubuf) + T_transpose_local_UB2 = tik_instance.Tensor(dtype, [28672], name="T_transpose_local_UB2", + scope=tik.scope_ubuf) + tik_instance.data_move(input_1_local_UB2, input_x[block_idx * 200704 + 10752], 0, 16, 112, 672, 0) + with tik_instance.for_range(0, 112) as cc7: + with tik_instance.for_range(0, 16) as cc8: + tik_instance.vadds(16, T_transpose_local_UB2[cc7 * 256 + cc8 * 16], + input_1_local_UB2[1792 * cc8 + cc7 * 16], zero, 1, 1, 1, 0, 0) + tik_instance.data_move(res[block_idx * 200704 + 172032], T_transpose_local_UB2, 0, 1, 1792, 0, 0) + + elif tuple(input_x_shape) == (32, 16, 14, 14, 16): + with tik_instance.for_range(0, 32, block_num=32) as block_idx: + zero = tik_instance.Scalar(dtype="float16", init_value=0) + with tik_instance.for_range(0, 2, thread_num=2) as db_idx: + input_1_local_UB = tik_instance.Tensor(dtype, [25088], name="input_1_local_UB", scope=tik.scope_ubuf) + T_transpose_local_UB = tik_instance.Tensor(dtype, [25088], name="T_transpose_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_1_local_UB, input_x[block_idx * 50176 + 1568 * db_idx], 0, 16, 98, 98, 0) + with tik_instance.for_range(0, 98) as cc7: + with tik_instance.for_range(0, 16) as cc8: + tik_instance.vadds(16, T_transpose_local_UB[cc7 * 256 + cc8 * 16], + input_1_local_UB[1568 * cc8 + cc7 * 16], zero, 1, 1, 1, 0, 0) + tik_instance.data_move(res[block_idx * 50176 + 25088 * db_idx], T_transpose_local_UB, 0, 1, 1568, 0, 0) + elif tuple(input_x_shape) == (32, 128, 7, 7, 16) and tuple(perm) == (0, 2, 3, 1, 4) and dtype == "float16": + with tik_instance.for_range(0, 32, block_num=32) as block_idx: + with tik_instance.for_range(0, 7, thread_num=2) as cc1: + input_x_ub = tik_instance.Tensor(dtype, [1, 128, 1, 7, 16], name="input_1_local_UB", + scope=tik.scope_ubuf) + transpose_ub = tik_instance.Tensor(dtype, [1, 1, 7, 128, 16], name="transpose_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[block_idx, 0, cc1, 0, 0], 0, 128, 7, 42, 0) + with tik_instance.for_range(0, 7) as cc7: + with tik_instance.for_range(0, 128) as cc8: + tik_instance.vadds(16, transpose_ub[0, 0, cc7, cc8, 0], input_x_ub[0, cc8, 0, cc7, 0], 0, + 1, 1, 1, 0, 0) + tik_instance.data_move(res[block_idx * 100352 + 14336 * cc1], transpose_ub, 0, 1, 896, 0, 0) + + elif tuple(input_x_shape) == (32, 32, 7, 7, 16) and tuple(perm) == (0, 2, 3, 1, 4) and dtype == "float16": + with tik_instance.for_range(0, 32, block_num=32) as block_idx: + input_x_ub = tik_instance.Tensor(dtype, [1, 32, 7, 7, 16], name="input_1_local_UB", + scope=tik.scope_ubuf) + transpose_ub = tik_instance.Tensor(dtype, [1, 7, 7, 32, 16], name="transpose_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[block_idx, 0, 0, 0, 0], 0, 1, 1568, 0, 0) + with tik_instance.for_range(0, 7) as cc1: + with tik_instance.for_range(0, 7) as cc2: + with tik_instance.for_range(0, 32) as cc3: + tik_instance.vadds(16, transpose_ub[0, cc1, cc2, cc3, 0], input_x_ub[0, cc3, cc1, cc2, 0], 0, + 1, 1, 1, 0, 0) + tik_instance.data_move(res[block_idx * 25088], transpose_ub, 0, 1, 1568, 0, 0) + + elif tuple(input_x_shape) == (32, 32, 14, 14, 16) and tuple(perm) == (0, 2, 3, 1, 4) and dtype == "float16": + def _inner_compute(split_index): + input_x_ub = tik_instance.Tensor(dtype, [1, 32, 2, 14, 16], name="input_1_local_UB", + scope=tik.scope_ubuf) + transpose_ub = tik_instance.Tensor(dtype, [1, 2, 14, 32, 16], name="transpose_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[block_idx, 0, split_index * 2, 0, 0], 0, 32, 28, 168, 0) + with tik_instance.for_range(0, 2) as cc2: + with tik_instance.for_range(0, 14) as cc3: + with tik_instance.for_range(0, 32) as cc4: + tik_instance.vadds(16, transpose_ub[0, cc2, cc3, cc4, 0], input_x_ub[0, cc4, cc2, cc3, 0], + 0, 1, 1, 1, 0, 0) + tik_instance.data_move(res[block_idx * 100352 + split_index * 2 * 7168], transpose_ub, 0, 1, 896, 0, 0) + + with tik_instance.for_range(0, 32, block_num=32) as block_idx: + with tik_instance.for_range(0, 6, thread_num=2) as cc1: + _inner_compute(cc1) + _inner_compute(6) + elif tuple(input_x_shape) == (32, 64, 14, 14, 16) and tuple(perm) == (0, 2, 3, 1, 4) and dtype == "float16": + def _inner_compute(split_index, block_idx): + input_x_ub = tik_instance.Tensor(dtype, [1, 64, 2, 14, 16], name="input_1_local_UB", + scope=tik.scope_ubuf) + transpose_ub = tik_instance.Tensor(dtype, [1, 2, 14, 64, 16], name="transpose_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[block_idx, 0, split_index * 2, 0, 0], 0, 64, 28, 168, 0) + with tik_instance.for_range(0, 2) as cc2: + with tik_instance.for_range(0, 14) as cc3: + with tik_instance.for_range(0, 64) as cc4: + tik_instance.vadds(16, transpose_ub[0, cc2, cc3, cc4, 0], input_x_ub[0, cc4, cc2, cc3, 0], + 0, 1, 1, 1, 0, 0) + tik_instance.data_move(res[block_idx * 200704 + split_index * 2 * 14336], transpose_ub, 0, 1, 1792, 0, 0) + + with tik_instance.for_range(0, 32, block_num=32) as block_idx: + with tik_instance.for_range(0, 6, thread_num=2) as cc1: + _inner_compute(cc1, block_idx) + _inner_compute(6, block_idx) + + tik_instance.BuildCCE(kernel_name, inputs=[input_x], outputs=[res]) + return tik_instance diff --git a/mindspore/ops/_op_impl/custom_op/batch_matmul_impl.py b/mindspore/ops/_op_impl/custom_op/batch_matmul_impl.py deleted file mode 100644 index e2afa96a7d6..00000000000 --- a/mindspore/ops/_op_impl/custom_op/batch_matmul_impl.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""batch_matmul_impl""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "CusBatchMatMul", - "imply_type": "TBE", - "fusion_type": "OPAQUE", - "async_flag": false, - "binfile_name": "batchmatmul.so", - "compute_cost": 10, - "kernel_name": "CusBatchMatMul", - "partial_flag": true, - "attr": [ - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float32" - ], - "format": [ - "DefaultFormat" - ], - "name": "x1", - "need_compile": false, - "param_type": "required", - "shape": "all" - }, - { - "index": 1, - "dtype": [ - "float32" - ], - "format": [ - "DefaultFormat" - ], - "name": "x2", - "need_compile": false, - "param_type": "required", - "shape": "all" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float32" - ], - "format": [ - "DefaultFormat" - ], - "name": "y", - "need_compile": false, - "param_type": "required", - "shape": "all" - } - ] -}""") -def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=True, kernel_name="batchmatmul"): - """CusBatchMatMul""" - return diff --git a/mindspore/ops/_op_impl/custom_op/cholesky_trsm.py b/mindspore/ops/_op_impl/custom_op/cholesky_trsm.py deleted file mode 100644 index 5c38dfc25d4..00000000000 --- a/mindspore/ops/_op_impl/custom_op/cholesky_trsm.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""CusCholeskyTrsm""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "CusCholeskyTrsm", - "imply_type": "TBE", - "fusion_type": "OPAQUE", - "async_flag": false, - "binfile_name": "choleskytrsm.so", - "compute_cost": 10, - "kernel_name": "CusCholeskyTrsm", - "partial_flag": true, - "attr": [ - - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float32" - ], - "format": [ - "DefaultFormat" - ], - "name": "x1", - "need_compile": false, - "param_type": "required", - "shape": "all" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float32" - ], - "format": [ - "DefaultFormat" - ], - "name": "y", - "need_compile": false, - "param_type": "required", - "shape": "all" - } - ] -}""") -def CusCholeskyTrsm(input_x, output, kernel_name): - """CusCholeskyTrsm""" - return diff --git a/mindspore/ops/_op_impl/custom_op/fused_abs_max1.py b/mindspore/ops/_op_impl/custom_op/fused_abs_max1.py deleted file mode 100644 index b9a0d452738..00000000000 --- a/mindspore/ops/_op_impl/custom_op/fused_abs_max1.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""CusFusedAbsMax1""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "CusFusedAbsMax1", - "imply_type": "TBE", - "fusion_type": "OPAQUE", - "async_flag": false, - "binfile_name": "fusedabsmax1.so", - "compute_cost": 10, - "kernel_name": "CusFusedAbsMax1", - "partial_flag": true, - "attr": [ - { - "name": "origin_shape", - "param_type": "required", - "type": "listInt", - "value": "all" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float32" - ], - "format": [ - "DefaultFormat" - ], - "name": "x1", - "need_compile": false, - "param_type": "required", - "shape": "all" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float32" - ], - "format": [ - "DefaultFormat" - ], - "name": "y", - "need_compile": false, - "param_type": "required", - "shape": "all" - } - ] -}""") -def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_max1"): - """CusFusedAbsMax1""" - return diff --git a/mindspore/ops/_op_impl/custom_op/img2col_impl.py b/mindspore/ops/_op_impl/custom_op/img2col_impl.py deleted file mode 100644 index 5137d4d7e70..00000000000 --- a/mindspore/ops/_op_impl/custom_op/img2col_impl.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""CusImg2ColNC1HWC0""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "CusImg2ColNC1HWC0", - "imply_type": "TBE", - "fusion_type": "OPAQUE", - "async_flag": false, - "binfile_name": "img2colnc1hwc0.so", - "compute_cost": 10, - "kernel_name": "CusImg2ColNC1HWC0", - "partial_flag": true, - "attr": [ - { - "name": "ksizes", - "param_type": "required", - "type": "listInt", - "value": "all" - }, - { - "name": "strides", - "param_type": "required", - "type": "listInt", - "value": "all" - }, - { - "name": "dilates", - "param_type": "required", - "type": "listInt", - "value": "all" - }, - { - "name": "padding", - "param_type": "required", - "type": "str", - "value": "all" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16" - ], - "format": [ - "NC1HWC0" - ], - "name": "x1", - "need_compile": false, - "param_type": "required", - "shape": "all" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16" - ], - "format": [ - "FRACTAL_NZ" - ], - "name": "y", - "need_compile": false, - "param_type": "required", - "shape": "all" - } - ] -}""") -def CusImg2ColNC1HWC0(input_x, output, ksizes, strides, dilates, padding, kernel_name="img2col"): - """CusImg2ColNC1HWC0""" - return diff --git a/mindspore/ops/_op_impl/custom_op/matmul_cube_dense_left.py b/mindspore/ops/_op_impl/custom_op/matmul_cube_dense_left.py deleted file mode 100644 index 300410eb4a3..00000000000 --- a/mindspore/ops/_op_impl/custom_op/matmul_cube_dense_left.py +++ /dev/null @@ -1,101 +0,0 @@ -# -*- coding:utf-8 -*- -""" -copyright 2020 Huawei Technologies Co., Ltd - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License == distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - -matmul -""" -from __future__ import absolute_import - -from mindspore.ops.op_info_register import op_info_register -from topi.cce import util - -# General limitation of the size for input shape: 2**31 -SHAPE_SIZE_LIMIT = 2147483648 -NoneType = type(None) - - -@op_info_register("""{ - "op_name": "CusMatMulCubeDenseLeft", - "imply_type": "TBE", - "fusion_type": "OPAQUE", - "async_flag": false, - "binfile_name": "matmulcubedenseleft.so", - "compute_cost": 10, - "kernel_name": "CusMatMulCubeDenseLeft", - "partial_flag": true, - "attr": [ - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16" - ], - "format": [ - "DefaultFormat" - ], - "name": "x1", - "need_compile": false, - "param_type": "required", - "shape": "all" - }, - { - "index": 1, - "dtype": [ - "float16" - ], - "format": [ - "FRACTAL_NZ" - ], - "name": "x2", - "need_compile": false, - "param_type": "required", - "shape": "all" - }, - { - "index": 2, - "dtype": [ - "float16" - ], - "format": [ - "DefaultFormat" - ], - "name": "x3", - "need_compile": false, - "param_type": "optional", - "shape": "all" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16" - ], - "format": [ - "FRACTAL_NZ" - ], - "name": "y", - "need_compile": false, - "param_type": "required", - "shape": "all" - } - ] -}""") -@util.check_input_type(dict, dict, (dict, NoneType), dict, bool, bool, str) -def CusMatMulCubeDenseLeft(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, - kernel_name="matmulcube"): - """CusMatMulCubeDenseLeft""" - return diff --git a/mindspore/ops/_op_impl/custom_op/matmul_cube_fracz_left_cast_impl.py b/mindspore/ops/_op_impl/custom_op/matmul_cube_fracz_left_cast_impl.py deleted file mode 100644 index 3da1593dfd3..00000000000 --- a/mindspore/ops/_op_impl/custom_op/matmul_cube_fracz_left_cast_impl.py +++ /dev/null @@ -1,102 +0,0 @@ -# -*- coding:utf-8 -*- -""" -copyright 2020 Huawei Technologies Co., Ltd - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License == distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - -matmul -""" -from __future__ import absolute_import - -from mindspore.ops.op_info_register import op_info_register -from topi.cce import util - -# General limitation of the size for input shape: 2**31 -SHAPE_SIZE_LIMIT = 2147483648 -NoneType = type(None) - - -@op_info_register("""{ - "op_name": "CusMatMulCubeFraczLeftCast", - "imply_type": "TBE", - "fusion_type": "OPAQUE", - "async_flag": false, - "binfile_name": "matmulcubefraczleftcast.so", - "compute_cost": 10, - "kernel_name": "CusMatMulCubeFraczLeftCast", - "partial_flag": true, - "attr": [ - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16" - ], - "format": [ - "DefaultFormat" - ], - "name": "x1", - "need_compile": false, - "param_type": "required", - "shape": "all" - }, - { - "index": 1, - "dtype": [ - "float32" - ], - "format": [ - "FracZ" - ], - "name": "x2", - "need_compile": false, - "param_type": "required", - "shape": "all" - }, - { - "index": 2, - "dtype": [ - "float16" - ], - "format": [ - "DefaultFormat" - ], - "name": "x3", - "need_compile": false, - "param_type": "optional", - "shape": "all" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16" - ], - "format": [ - "FracZ" - ], - "name": "y", - "need_compile": false, - "param_type": "required", - "shape": "all" - } - ] -}""") -# pylint: disable=locally-disabled,too-many-arguments, too-many-locals, too-many-statements -@util.check_input_type(dict, dict, (dict, NoneType), dict, bool, bool, str) -def CusMatMulCubeFraczLeftCast(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, - kernel_name="CusMatMulCubeFraczLeftCast"): - """CusMatMulCubeFraczLeftCast""" - return diff --git a/mindspore/ops/_op_impl/custom_op/matmul_cube_fracz_right_mul_impl.py b/mindspore/ops/_op_impl/custom_op/matmul_cube_fracz_right_mul_impl.py deleted file mode 100644 index 7fc2ba35d16..00000000000 --- a/mindspore/ops/_op_impl/custom_op/matmul_cube_fracz_right_mul_impl.py +++ /dev/null @@ -1,113 +0,0 @@ -#!/usr/bin/env python -# -*- coding:utf-8 -*- -""" -copyright 2020 Huawei Technologies Co., Ltd - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License == distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - -matmul -""" -from __future__ import absolute_import - -from mindspore.ops.op_info_register import op_info_register - -# General limitation of the size for input shape: 2**31 -SHAPE_SIZE_LIMIT = 2147483648 -NoneType = type(None) - - -@op_info_register("""{ - "op_name": "CusMatMulCubeFraczRightMul", - "imply_type": "TBE", - "fusion_type": "OPAQUE", - "async_flag": false, - "binfile_name": "matmulcubefraczrightmul.so", - "compute_cost": 10, - "kernel_name": "CusMatMulCubeFraczRightMul", - "partial_flag": true, - "attr": [ - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16" - ], - "format": [ - "FracZ" - ], - "name": "x1", - "need_compile": false, - "param_type": "required", - "shape": "all" - }, - { - "index": 1, - "dtype": [ - "float16" - ], - "format": [ - "DefaultFormat" - ], - "name": "x2", - "need_compile": false, - "param_type": "required", - "shape": "all" - }, - { - "index": 2, - "dtype": [ - "float32" - ], - "format": [ - "DefaultFormat" - ], - "name": "x3", - "need_compile": false, - "param_type": "required", - "shape": "all" - }, - { - "index": 3, - "dtype": [ - "float16" - ], - "format": [ - "DefaultFormat" - ], - "name": "x4", - "need_compile": false, - "param_type": "optional", - "shape": "all" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float32" - ], - "format": [ - "FracZ" - ], - "name": "y", - "need_compile": false, - "param_type": "required", - "shape": "all" - } - ] -}""") -def CusMatMulCubeFraczRightMul(input_x1, input_x2, input_x3, bias=None, output_y={}, trans_a=False, trans_b=False, - kernel_name="matmulcube"): - """CusMatMulCubeFraczRightMul""" - return diff --git a/mindspore/ops/_op_impl/custom_op/matmul_cube_impl.py b/mindspore/ops/_op_impl/custom_op/matmul_cube_impl.py deleted file mode 100644 index 7c2d81e1d67..00000000000 --- a/mindspore/ops/_op_impl/custom_op/matmul_cube_impl.py +++ /dev/null @@ -1,114 +0,0 @@ -#!/usr/bin/env python -# -*- coding:utf-8 -*- -""" -copyright 2020 Huawei Technologies Co., Ltd - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License == distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - -matmul -""" -from __future__ import absolute_import - -from mindspore.ops.op_info_register import op_info_register -from topi.cce import util - -# General limitation of the size for input shape: 2**31 -SHAPE_SIZE_LIMIT = 2147483648 -NoneType = type(None) - - -@op_info_register("""{ - "op_name": "CusMatMulCube", - "imply_type": "TBE", - "fusion_type": "OPAQUE", - "async_flag": false, - "binfile_name": "matmulcube.so", - "compute_cost": 10, - "kernel_name": "CusMatMulCube", - "partial_flag": true, - "attr": [ - { - "name": "transpose_a", - "param_type": "required", - "type": "bool", - "value": "all" - }, - { - "name": "transpose_b", - "param_type": "required", - "type": "bool", - "value": "all" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16" - ], - "format": [ - "FRACTAL_NZ" - ], - "name": "x1", - "need_compile": false, - "param_type": "required", - "shape": "all" - }, - { - "index": 1, - "dtype": [ - "float16" - ], - "format": [ - "FRACTAL_NZ" - ], - "name": "x2", - "need_compile": false, - "param_type": "required", - "shape": "all" - }, - { - "index": 2, - "dtype": [ - "float16" - ], - "format": [ - "DefaultFormat" - ], - "name": "x3", - "need_compile": false, - "param_type": "optional", - "shape": "all" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float32" - ], - "format": [ - "FRACTAL_NZ" - ], - "name": "y", - "need_compile": false, - "param_type": "required", - "shape": "all" - } - ] -}""") -# pylint: disable=locally-disabled,too-many-arguments, too-many-locals, too-many-statements -@util.check_input_type(dict, dict, (dict, NoneType), dict, bool, bool, str) -def CusMatMulCube(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, kernel_name="matmulcube"): - """CusMatMulCube""" - return diff --git a/mindspore/ops/_op_impl/custom_op/matrix_combine_impl.py b/mindspore/ops/_op_impl/custom_op/matrix_combine_impl.py deleted file mode 100644 index 32045e7ccbd..00000000000 --- a/mindspore/ops/_op_impl/custom_op/matrix_combine_impl.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""CusMatrixCombine""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "CusMatrixCombine", - "imply_type": "TBE", - "fusion_type": "OPAQUE", - "async_flag": false, - "binfile_name": "matrixcombine.so", - "compute_cost": 10, - "kernel_name": "CusMatrixCombine", - "partial_flag": true, - "attr": [ - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float32" - ], - "format": [ - "DefaultFormat" - ], - "name": "x1", - "need_compile": false, - "param_type": "required", - "shape": "all" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float32" - ], - "format": [ - "DefaultFormat" - ], - "name": "y", - "need_compile": false, - "param_type": "required", - "shape": "all" - } - ] -}""") -def CusMatrixCombine(input_x, output, kernel_name="matrix_combine"): - """CusMatrixCombine""" - return diff --git a/mindspore/ops/_op_impl/custom_op/transpose02314_impl.py b/mindspore/ops/_op_impl/custom_op/transpose02314_impl.py deleted file mode 100644 index c5aebe523d5..00000000000 --- a/mindspore/ops/_op_impl/custom_op/transpose02314_impl.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""CusTranspose02314""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "CusTranspose02314", - "imply_type": "TBE", - "fusion_type": "OPAQUE", - "async_flag": false, - "binfile_name": "transpose02314.so", - "compute_cost": 10, - "kernel_name": "CusTranspose02314", - "partial_flag": true, - "attr": [ - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16" - ], - "format": [ - "NC1HWC0" - ], - "name": "x1", - "need_compile": false, - "param_type": "required", - "shape": "all" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16" - ], - "format": [ - "DefaultFormat" - ], - "name": "y", - "need_compile": false, - "param_type": "required", - "shape": "all" - } - ] -}""") -def CusTranspose02314(input_x, output, kernel_name="transpose021354"): - """CusTranspose02314""" - return diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index fca4f57b719..5af72eb0393 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -70,6 +70,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm, from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey, CheckBprop from . import _quant_ops from ._quant_ops import * +from .thor_ops import * __all__ = [ 'TensorAdd', @@ -262,5 +263,6 @@ __all__ = [ "SquareSumAll" ] +__all__.extend(thor_ops.__all__) __all__.extend(_quant_ops.__all__) __all__.sort() diff --git a/mindspore/ops/operations/thor_ops.py b/mindspore/ops/operations/thor_ops.py index 23593a26305..5e6ff4b9599 100644 --- a/mindspore/ops/operations/thor_ops.py +++ b/mindspore/ops/operations/thor_ops.py @@ -17,13 +17,26 @@ import mindspore as ms from mindspore.ops import prim_attr_register, PrimitiveWithInfer from mindspore.ops.composite import multitype_ops as C +__all__ = ["CusBatchMatMul", + "CusCholeskyTrsm", + "CusFusedAbsMax1", + "CusImg2Col", + "CusMatMulCubeDenseLeft", + "CusMatMulCubeFraczRightMul", + "CusMatMulCube", + "CusMatrixCombine", + "CusTranspose02314", + "CusMatMulCubeDenseRight", + "CusMatMulCubeFraczLeftCast", + ] + class CusBatchMatMul(PrimitiveWithInfer): - """CusMatMulCube definition""" + """CusBatchMatMul definition""" @prim_attr_register def __init__(self): - """init CusMatMulCube""" + """init CusBatchMatMul""" self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y']) def get_bprop(self): @@ -61,11 +74,11 @@ class CusCholeskyTrsm(PrimitiveWithInfer): class CusFusedAbsMax1(PrimitiveWithInfer): - """CusCholeskyTrsm definition""" + """CusFusedAbsMax1 definition""" @prim_attr_register def __init__(self, origin_shape=[-1, -1]): - """init CusCholeskyTrsm""" + """init CusFusedAbsMax1""" self.init_prim_io_names(inputs=['x1'], outputs=['y']) self.origin_shape = origin_shape @@ -126,7 +139,7 @@ class CusMatMulCubeDenseLeft(PrimitiveWithInfer): @prim_attr_register def __init__(self): - """init CusMatMulCube""" + """init CusMatMulCubeDenseLeft""" self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y']) def get_bprop(self): @@ -199,11 +212,11 @@ class CusMatMulCube(PrimitiveWithInfer): class CusMatrixCombine(PrimitiveWithInfer): - """CusMatMulCube definition""" + """CusMatrixCombine definition""" @prim_attr_register def __init__(self): - """init CusMatMulCube""" + """init CusMatrixCombine""" self.init_prim_io_names(inputs=['x'], outputs=['y']) def get_bprop(self): @@ -246,3 +259,45 @@ class CusTranspose02314(PrimitiveWithInfer): def infer_dtype(self, data1_dtype): return data1_dtype + + +class CusMatMulCubeDenseRight(PrimitiveWithInfer): + """CusMatMulCubeDenseRight definition""" + + @prim_attr_register + def __init__(self): + """init CusMatMulCubeDenseRight""" + self.init_prim_io_names(inputs=['x1', 'x2', 'x3'], outputs=['y']) + + def get_bprop(self): + def bprop(x1, x2, x3, out, dout): + return (C.zeros_like(x1), C.zeros_like(x2), C.zeros_like(x3)) + + return bprop + + def infer_shape(self, data1_shape, data2_shape, data3_shape): + return data1_shape + + def infer_dtype(self, data1_dtype, data2_dtype, data3_dtype): + return ms.common.dtype.tensor_type(getattr(ms, "float32")) + + +class CusMatMulCubeFraczLeftCast(PrimitiveWithInfer): + """CusMatMulCubeFraczLeftCast definition""" + + @prim_attr_register + def __init__(self): + """init CusMatMulCubeFraczLeftCast""" + self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y']) + + def get_bprop(self): + def bprop(x1, x2, out, dout): + return (C.zeros_like(x1), C.zeros_like(x2)) + + return bprop + + def infer_shape(self, data1_shape, data2_shape): + return data2_shape + + def infer_dtype(self, data1_dtype, data2_dtype): + return ms.common.dtype.tensor_type(getattr(ms, "float16"))