!1527 THOR ops master -> r0.3

Merge pull request !1527 from zongha/r0.3
mindspore-ci-bot 2020-05-27 20:02:31 +08:00 committed by Gitee
commit fac36e6a1a
33 changed files with 5218 additions and 922 deletions

View File

@@ -23,7 +23,7 @@ config = ed({
"loss_scale": 128,
"momentum": 0.9,
"weight_decay": 5e-4,
"epoch_size": 50,
"epoch_size": 45,
"buffer_size": 1000,
"image_height": 224,
"image_width": 224,
@@ -31,15 +31,7 @@ config = ed({
"save_checkpoint_steps": 5004,
"keep_checkpoint_max": 20,
"save_checkpoint_path": "./",
"lr_init": 0.01,
"lr_end": 0.00001,
"lr_max": 0.1,
"warmup_epochs": 0,
"lr_decay_mode": "cosine",
"label_smooth": 1,
"label_smooth_factor": 0.1,
"lr": 0.1,
"T_max": 90,
"eta_min": 0,
"frequency": 278
"frequency": 834
})
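For reference, the config block after this change reads roughly as follows. This is a sketch reconstructed only from the two hunks above: the second hunk's line counts (-31,15 → +31,7) imply that nine old lines are dropped and one is added, so two of the keys listed between the checkpoint settings and "frequency" must survive as context; since eval.py below still reads config.label_smooth and config.label_smooth_factor, those two are assumed to be the survivors. Keys between "image_width" and "save_checkpoint_steps" are unchanged and elided.

# Sketch only: reconstructed from the visible hunks; elided keys are unchanged.
from easydict import EasyDict as ed

config = ed({
    "loss_scale": 128,
    "momentum": 0.9,
    "weight_decay": 5e-4,
    "epoch_size": 45,              # was 50
    "buffer_size": 1000,
    "image_height": 224,
    "image_width": 224,
    # ... unchanged keys elided ...
    "save_checkpoint_steps": 5004,
    "keep_checkpoint_max": 20,
    "save_checkpoint_path": "./",
    "label_smooth": 1,             # assumed kept (still read by eval.py)
    "label_smooth_factor": 0.1,    # assumed kept (still read by eval.py)
    "frequency": 834               # was 278; assumed to be the step interval at which
                                   # THOR refreshes its second-order factors
})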

View File

@@ -0,0 +1,60 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
eval.
"""
import os
import argparse
from dataset_imagenet import create_dataset
from config import config
from mindspore import context
from mindspore.model_zoo.resnet import resnet50
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from crossentropy import CrossEntropy
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
parser.add_argument('--do_train', type=bool, default=False, help='Do train or not.')
parser.add_argument('--do_eval', type=bool, default=True, help='Do eval or not.')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
args_opt = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
context.set_context(device_id=device_id)
if __name__ == '__main__':
net = resnet50(class_num=config.class_num)
if not config.label_smooth:
config.label_smooth_factor = 0.0
loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
if args_opt.do_eval:
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size)
step_size = dataset.get_dataset_size()
if args_opt.checkpoint_path:
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)
model = Model(net, loss_fn=loss, metrics={'acc'})
res = model.eval(dataset)
print("result:", res, "ckpt=", args_opt.checkpoint_path)
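CrossEntropy is imported from crossentropy.py, which is not part of the hunks shown on this page. Assuming it follows the usual label-smoothing formulation (a one-hot target softened by smooth_factor, fed to softmax cross entropy), a minimal NumPy sketch of the loss this script builds; the function name and exact formulation here are this sketch's assumptions, not the repository's code:

import numpy as np

def label_smoothed_ce(logits, labels, num_classes, smooth_factor=0.1):
    """Sketch of label-smoothed cross entropy: (1 - eps) on the true class,
    eps / (num_classes - 1) on the rest, averaged over the batch."""
    off_value = smooth_factor / (num_classes - 1)
    targets = np.full((labels.shape[0], num_classes), off_value)
    targets[np.arange(labels.shape[0]), labels] = 1.0 - smooth_factor
    # log-softmax for numerical stability
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -(targets * log_probs).sum(axis=1).mean()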

View File

@@ -21,11 +21,6 @@ from mindspore.common.tensor import Tensor
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.parallel._utils import _get_device_num, _get_mirror_mean
from cus_ops.cus_matmul_cube_dense_right import CusMatMulCubeDenseRight
from cus_ops.cus_matmul_cube_fracz_left_cast import CusMatMulCubeFraczLeftCast
from cus_ops.cus_matmul_cube_dense_left import CusMatMulCubeDenseLeft
from cus_ops.cus_matmul_cube_fracz_right_mul import CusMatMulCubeFraczRightMul
from model.grad_reducer_thor import DistributedGradReducerThor
momentum_opt = C.MultitypeFuncGraph("momentum_opt")
@@ -68,10 +63,10 @@ class THOR(Optimizer):
self.matrix_G = ParameterTuple(matrix_G)
self.A_inv_max = ParameterTuple(A_inv_max)
self.G_inv_max = ParameterTuple(G_inv_max)
self.cube_matmul_left = CusMatMulCubeFraczLeftCast()
self.cube_matmul_left_fc = CusMatMulCubeDenseLeft()
self.cube_matmul_right_fc = CusMatMulCubeDenseRight()
self.cube_matmul_right_mul = CusMatMulCubeFraczRightMul()
self.cube_matmul_left = P.CusMatMulCubeFraczLeftCast()
self.cube_matmul_left_fc = P.CusMatMulCubeDenseLeft()
self.cube_matmul_right_fc = P.CusMatMulCubeDenseRight()
self.cube_matmul_right_mul = P.CusMatMulCubeFraczRightMul()
self.transpose = P.Transpose()
self.shape = P.Shape()
self.reshape = P.Reshape()
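This hunk only swaps the custom cube ops from the standalone cus_ops modules to their registered P.* counterparts; the math they implement is unchanged. Conceptually, and ignoring the FracZ/dense fractal layouts, loss scaling, and the A_inv_max/G_inv_max normalization, the preconditioned gradient THOR feeds into its momentum update is the KFAC-style product sketched below; a conceptual sketch under those assumptions, not the device implementation:

import numpy as np

def thor_preconditioned_grad(grad, a_inv, g_inv):
    """Sketch: precondition a layer's weight gradient with the inverses of the
    Kronecker factors (A from layer inputs, G from output gradients).
    grad: [out_dim, in_dim], a_inv: [in_dim, in_dim], g_inv: [out_dim, out_dim]."""
    return g_inv @ grad @ a_inv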

View File

@@ -23,19 +23,9 @@ from mindspore.common.tensor import Tensor
from mindspore.nn.cell import Cell
from mindspore.nn.layer.activation import get_activation
from mindspore.ops import operations as P
from cus_ops.cus_batch_matmul import CusBatchMatMul
from cus_ops.cus_cholesky_trsm import CusCholeskyTrsm
from cus_ops.cus_fused_abs_max1 import CusFusedAbsMax1
from cus_ops.cus_img2col import CusImg2Col
from cus_ops.cus_matmul_cube import CusMatMulCube
from cus_ops.cus_matrix_combine import CusMatrixCombine
from cus_ops.cus_transpose02314 import CusTranspose02314
import numpy as np
C0 = 16
def caculate_device_shape(matrix_dim, channel, is_A):
ll = (0)
if is_A:
@@ -153,11 +143,11 @@ class Conv2d_Thor(_Conv):
group=self.group
)
self.img2col = CusImg2Col(ksizes=ksizes, strides=strides)
self.cube_matmul = CusMatMulCube(transpose_a=True)
self.matrix_combine = CusMatrixCombine()
self.cholesky = CusCholeskyTrsm()
self.transpose02314 = CusTranspose02314()
self.img2col = P.CusImg2Col(ksizes=ksizes, strides=strides)
self.cube_matmul = P.CusMatMulCube(transpose_a=True)
self.matrix_combine = P.CusMatrixCombine()
self.cholesky = P.CusCholeskyTrsm()
self.transpose02314 = P.CusTranspose02314()
self.matrix_A_dim = self.in_channels * self.kernel_size[0] * self.kernel_size[1]
self.matrix_G_dim = self.out_channels
self.matrix_A_device_shape, self.matrix_A_device_dim = caculate_device_shape(self.matrix_A_dim,
@@ -190,7 +180,7 @@ class Conv2d_Thor(_Conv):
self.mul = P.Mul()
self.cast = P.Cast()
self.damping = Tensor(damping)
self.vector_matmul = CusBatchMatMul()
self.vector_matmul = P.CusBatchMatMul()
self.diag_block_dim = 128
self.channels_slice_flag = False
if self.in_channels % C0 != 0:
@@ -221,8 +211,8 @@
self.dampingA = Tensor(np.identity(dampingA_dim), mstype.float32)
self.dampingG = Tensor(np.identity(dampingG_dim), mstype.float32)
self.fused_abs_max1 = CusFusedAbsMax1([self.matrix_A_dim, self.matrix_A_dim])
self.fused_abs_max2 = CusFusedAbsMax1()
self.fused_abs_max1 = P.CusFusedAbsMax1([self.matrix_A_dim, self.matrix_A_dim])
self.fused_abs_max2 = P.CusFusedAbsMax1()
self.log = P.Log()
self.exp = P.Exp()
self.sqrt = P.Sqrt()
@@ -375,9 +365,9 @@ class Dense_Thor(Cell):
self.fake_G = Tensor(np.zeros([63, 63, 16, 16]).astype(np.float16))
self.matmul = P.MatMul(transpose_b=True)
self.cube_matmul = CusMatMulCube(transpose_a=True)
self.matrix_combine = CusMatrixCombine()
self.cholesky = CusCholeskyTrsm()
self.cube_matmul = P.CusMatMulCube(transpose_a=True)
self.matrix_combine = P.CusMatrixCombine()
self.cholesky = P.CusCholeskyTrsm()
self.shape = P.Shape()
self.reshape = P.Reshape()
self.transpose = P.Transpose()
@@ -386,7 +376,7 @@ class Dense_Thor(Cell):
self.cast = P.Cast()
self.damping = Tensor(damping)
self.loss_scale = Tensor(1 / loss_scale, mstype.float16)
self.vector_matmul = CusBatchMatMul()
self.vector_matmul = P.CusBatchMatMul()
self.pad = P.Pad(((0, 24), (0, 24)))
self.pad1 = P.Pad(((0, 8), (0, 8)))
self.slice = P.Slice()
@@ -396,8 +386,8 @@
self.axis = 0
self.A_inv_max = Parameter(initializer(0, [1], mstype.float32), name="A_inv_max", requires_grad=False)
self.G_inv_max = Parameter(initializer(0, [1], mstype.float32), name="G_inv_max", requires_grad=False)
self.fused_abs_max1 = CusFusedAbsMax1([1000, 1000])
self.fused_abs_max2 = CusFusedAbsMax1()
self.fused_abs_max1 = P.CusFusedAbsMax1([1000, 1000])
self.fused_abs_max2 = P.CusFusedAbsMax1()
self.log = P.Log()
self.exp = P.Exp()
self.dampingA = Tensor(np.identity(2048), mstype.float32)
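The dampingA/dampingG identity tensors above regularize the factored covariances before inversion, and the CusCholeskyTrsm + CusMatMulCube(transpose_a=True) pair is presumably what rebuilds the inverse from the Cholesky factor on device. A NumPy sketch of the damped inverse this machinery is set up to compute, assuming a symmetric positive semi-definite factor:

import numpy as np

def damped_inverse(cov, damping):
    """Sketch: add damping * I, Cholesky-factor, then assemble the inverse
    from the triangular factor (cov + damping*I)^-1 = L^-T @ L^-1."""
    a = cov + damping * np.eye(cov.shape[0], dtype=cov.dtype)
    l = np.linalg.cholesky(a)                               # a = l @ l.T
    l_inv = np.linalg.solve(l, np.eye(l.shape[0], dtype=cov.dtype))
    return l_inv.T @ l_inv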

View File

@@ -45,8 +45,7 @@ do
mkdir ./train_parallel$i
cp *.py ./train_parallel$i
cp *.sh ./train_parallel$i
cp -r second_order ./train_parallel$i/second_order
cp -r test_ops ./train_parallel$i/test_ops
cp -r model ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"

View File

@@ -0,0 +1,64 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
echo "Usage: sh run_infer.sh [DATASET_PATH] [CHECKPOINT_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
if [ ! -d $PATH1 ]
then
echo "error: DATASET_PATH=$1 is not a directory"
exit 1
fi
if [ ! -f $PATH2 ]
then
echo "error: CHECKPOINT_PATH=$2 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
if [ -d "infer" ];
then
rm -rf ./infer
fi
mkdir ./infer
cp *.py ./infer
cp *.sh ./infer
cd ./infer || exit
env > env.log
echo "start infering for device $DEVICE_ID"
python eval.py --do_eval=True --dataset_path=$PATH1 --checkpoint_path=$PATH2 &> log &
cd ..

View File

@@ -109,7 +109,7 @@ if __name__ == '__main__':
step_size = dataset.get_dataset_size()
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
lr = Tensor(get_model_lr(0, 0.05, 6, 70, 5004))
lr = Tensor(get_model_lr(0, 0.045, 6, 70, 5004))
opt = THOR(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
filter(lambda x: 'matrix_A' in x.name, net.get_parameters()),
filter(lambda x: 'matrix_G' in x.name, net.get_parameters()),

View File

@@ -19,5 +19,6 @@ from .aicpu import *
if "Windows" not in platform.system():
from .akg.gpu import *
from .tbe import *
from ._custom_op import *
__all__ = []

View File

@@ -0,0 +1,16 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""custom ops"""

View File

@@ -0,0 +1,257 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""batch_matmul_impl"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
from te import tik
from topi.cce import util
cus_batchmatmul_op_info = TBERegOp("CusBatchMatMul") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("batchmatmul.so") \
.compute_cost(10) \
.kernel_name("CusBatchMatMul") \
.partial_flag(True) \
.input(0, "x1", False, "required", "all") \
.input(1, "x2", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
def _get_flattern_shape(shape):
"""_get_flattern_shape"""
flattern_shape = 1
for dim in shape:
flattern_shape *= dim
return (flattern_shape,)
def _inner_matmul_new(tik_instance, dtype, input1, input1_index, input2, input2_index, res, res_index):
"""_inner_matmul_new"""
input_1_local_UB = tik_instance.Tensor(dtype, [128], name="input_1_local_UB", scope=tik.scope_ubuf)
t_1_0_local_UB = tik_instance.Tensor(dtype, [64 * 128], name="t_1_0_local_UB", scope=tik.scope_ubuf)
tik_instance.data_move(input_1_local_UB, input1[input1_index], 0, 1, 16, 0, 0)
with tik_instance.for_range(0, 2) as vec_i:
tik_instance.vadds(64, t_1_0_local_UB[vec_i * 64], input_1_local_UB[vec_i * 64], 0, 64, 1, 1, 16, 0)
with tik_instance.for_range(0, 2, thread_num=2) as thread_idx2:
input_2_local_UB = tik_instance.Tensor(dtype, [64 * 128], name="input_2_local_UB",
scope=tik.scope_ubuf)
t_1_local_UB = input_2_local_UB
bisec_last_axis_local_UB = input_2_local_UB
matmul_hybrid_f_t_local_UB = tik_instance.Tensor(dtype, [64], name="matmul_hybrid_f_t_local_UB",
scope=tik.scope_ubuf)
matmul_hybrid_f_t_local_UB_dst_tmp = tik_instance.Tensor(dtype, [64],
name="matmul_hybrid_f_t_local_UB_dst_tmp",
scope=tik.scope_ubuf)
tik_instance.vector_dup(64, matmul_hybrid_f_t_local_UB, 0, 1, 1, 8)
tik_instance.data_move(input_2_local_UB, input2[input2_index + thread_idx2 * 8192], 0, 1, 1024, 0, 0)
tik_instance.vmul(64, t_1_local_UB, t_1_0_local_UB, input_2_local_UB, 128, 1, 1, 1, 8, 8, 8)
tik_instance.vadd(64, bisec_last_axis_local_UB, t_1_local_UB, t_1_local_UB[64], 64, 1, 1, 1,
16, 16, 16)
tik_instance.vector_dup(64, matmul_hybrid_f_t_local_UB_dst_tmp, 0, 1, 1, 8)
with tik_instance.for_range(0, 64) as cc6:
tik_instance.vcadd(64, matmul_hybrid_f_t_local_UB_dst_tmp[cc6], bisec_last_axis_local_UB[cc6 * 128],
1, 1, 1, 8)
tik_instance.vadd(64, matmul_hybrid_f_t_local_UB, matmul_hybrid_f_t_local_UB_dst_tmp,
matmul_hybrid_f_t_local_UB, 1, 1, 1, 1, 8, 8, 8)
tik_instance.data_move(res[res_index + thread_idx2 * 64],
matmul_hybrid_f_t_local_UB, 0, 1, 8, 0, 0)
def _inner_matmul_new_1_64_32_64(tik_instance, dtype, input1, input1_index, input2, input2_index, res, res_index):
"""_inner_matmul_new_1_64_32_64"""
input_1_local_UB = tik_instance.Tensor(dtype, [64], name="input_1_local_UB", scope=tik.scope_ubuf)
tik_instance.data_move(input_1_local_UB, input1[input1_index], 0, 1, 8, 0, 0)
with tik_instance.for_range(0, 2, thread_num=2) as thread_idx2:
input_2_local_UB = tik_instance.Tensor(dtype, [32 * 64], name="input_2_local_UB",
scope=tik.scope_ubuf)
t_1_local_UB = input_2_local_UB
matmul_hybrid_f_t_local_UB = tik_instance.Tensor(dtype, [32], name="matmul_hybrid_f_t_local_UB",
scope=tik.scope_ubuf)
tik_instance.data_move(input_2_local_UB, input2[input2_index + thread_idx2 * 2048], 0, 1, 256, 0, 0)
tik_instance.vmul(64, t_1_local_UB, input_1_local_UB, input_2_local_UB, 32, 1, 1, 1, 8, 0, 8)
with tik_instance.for_range(0, 32) as cc6:
tik_instance.vcadd(64, matmul_hybrid_f_t_local_UB[cc6], t_1_local_UB[cc6 * 64],
1, 1, 1, 8)
tik_instance.data_move(res[res_index + thread_idx2 * 32],
matmul_hybrid_f_t_local_UB, 0, 1, 4, 0, 0)
@op_info_register(cus_batchmatmul_op_info)
def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=True, kernel_name="batchmatmul"):
"""CusBatchMatMul"""
if util.get_product_version() == util.VERSION_MINI:
tik_instance = tik.Tik(tik.Dprofile("v100", "mini"))
else:
tik_instance = tik.Tik(tik.Dprofile("v100", "cloud"))
x1_shape = input_x1.get("shape")
dtype = input_x1.get("dtype").lower()
x2_shape = input_x2.get("shape")
if dtype != input_x2.get("dtype").lower():
raise RuntimeError("dtype of input_x1 and input_x2 must be same, but got %s vs %s" % (
dtype, input_x2.get("dtype").lower()))
input_shape = (tuple(x1_shape), tuple(x2_shape), dtype, transpose_a, transpose_b)
support_shape = [((8, 128, 128), (8, 128, 128), "float32", False, True),
((36, 128, 128), (36, 128, 128), "float32", False, True),
((5, 128, 128), (5, 128, 128), "float32", False, True),
((18, 128, 128), (18, 128, 128), "float32", False, True),
((16, 128, 128), (16, 128, 128), "float32", False, True),
((9, 128, 128), (9, 128, 128), "float32", False, True),
((1, 64, 64), (1, 64, 64), "float32", False, True),
((1, 128, 128), (1, 128, 128), "float32", False, True),
((4, 128, 128), (4, 128, 128), "float32", False, True),
((2, 128, 128), (2, 128, 128), "float32", False, True)]
if input_shape not in support_shape:
raise RuntimeError("input_shape %s is not supported" % str(input_shape))
# if not transpose_a and transpose_b:
batch, m, k = x1_shape
input1_shape = _get_flattern_shape(x1_shape)
input1 = tik_instance.Tensor(dtype, input1_shape, name="input1", scope=tik.scope_gm)
input2_shape = _get_flattern_shape(x2_shape)
input2 = tik_instance.Tensor(dtype, input2_shape, name="input2", scope=tik.scope_gm)
output_shape = x1_shape
res_shape = _get_flattern_shape(output_shape)
res = tik_instance.Tensor(dtype, res_shape, name="res", scope=tik.scope_gm)
if input_shape == ((36, 128, 128), (36, 128, 128), "float32", False, True):
with tik_instance.for_range(0, 18, block_num=18) as block_idx:
with tik_instance.for_range(0, 2) as cc0:
with tik_instance.for_range(0, 128, thread_num=2) as cc1:
input1_index = block_idx * 32768 + cc0 * 16384 + cc1 * 128
input2_index = block_idx * 32768 + cc0 * 16384
res_index = block_idx * 32768 + cc0 * 16384 + cc1 * 128
_inner_matmul_new(tik_instance, dtype,
input1, input1_index,
input2, input2_index,
res, res_index)
if input_shape == ((5, 128, 128), (5, 128, 128), "float32", False, True):
with tik_instance.for_range(0, 30, block_num=30) as block_idx:
with tik_instance.for_range(0, 11) as cc1_db:
with tik_instance.for_range(0, 2, thread_num=2) as thread_idx:
with tik_instance.if_scope(((((block_idx % 6) * 22) + (cc1_db * 2) + thread_idx) < 128)):
input_1_local_UB = tik_instance.Tensor(dtype, [128], name="input_1_local_UB",
scope=tik.scope_ubuf)
t_1_0_local_UB = tik_instance.Tensor(dtype, [64 * 128], name="t_1_0_local_UB",
scope=tik.scope_ubuf)
tik_instance.data_move(input_1_local_UB, input1[
(block_idx // 6) * 16384 + (block_idx % 6) * 2816 + cc1_db * 256 + thread_idx * 128], 0, 1,
16, 0, 0)
with tik_instance.for_range(0, 2) as vec_i:
tik_instance.vadds(64, t_1_0_local_UB[vec_i * 64], input_1_local_UB[vec_i * 64], 0,
64, 1, 1, 16, 0)
with tik_instance.for_range(0, 2, thread_num=2) as thread_idx2:
input_2_local_UB = tik_instance.Tensor(dtype, [64 * 128], name="input_2_local_UB",
scope=tik.scope_ubuf)
t_1_local_UB = input_2_local_UB
bisec_last_axis_local_UB = input_2_local_UB
matmul_hybrid_f_t_local_UB = tik_instance.Tensor(dtype, [64],
name="matmul_hybrid_f_t_local_UB",
scope=tik.scope_ubuf)
matmul_hybrid_f_t_local_UB_dst_tmp = tik_instance.Tensor(dtype, [64],
name="matmul_hybrid_f_t_local_UB_dst_tmp",
scope=tik.scope_ubuf)
tik_instance.vector_dup(64, matmul_hybrid_f_t_local_UB, 0, 1, 1, 8)
tik_instance.data_move(input_2_local_UB,
input2[(block_idx // 6) * 16384 + thread_idx2 * 8192], 0, 1,
1024, 0, 0)
tik_instance.vmul(64, t_1_local_UB, t_1_0_local_UB, input_2_local_UB, 128, 1, 1, 1, 8, 8, 8)
tik_instance.vadd(64, bisec_last_axis_local_UB, t_1_local_UB, t_1_local_UB[64], 64, 1, 1, 1,
16, 16, 16)
tik_instance.vector_dup(64, matmul_hybrid_f_t_local_UB_dst_tmp, 0, 1, 1, 8)
with tik_instance.for_range(0, 64) as cc6:
tik_instance.vcadd(64, matmul_hybrid_f_t_local_UB_dst_tmp[cc6],
bisec_last_axis_local_UB[cc6 * 128],
1, 1, 1, 8)
tik_instance.vadd(64, matmul_hybrid_f_t_local_UB, matmul_hybrid_f_t_local_UB_dst_tmp,
matmul_hybrid_f_t_local_UB, 1, 1, 1, 1, 8, 8, 8)
tik_instance.data_move(
res[(block_idx // 6) * 16384 + (block_idx % 6) * 2816 + cc1_db * 256 +
thread_idx * 128 + thread_idx2 * 64],
matmul_hybrid_f_t_local_UB, 0, 1, 8, 0, 0)
if input_shape == ((18, 128, 128), (18, 128, 128), "float32", False, True):
with tik_instance.for_range(0, 18, block_num=18) as block_idx:
with tik_instance.for_range(0, 128, thread_num=2) as cc0:
input1_index = block_idx * 16384 + cc0 * 128
input2_index = block_idx * 16384
res_index = block_idx * 16384 + cc0 * 128
_inner_matmul_new(tik_instance, dtype,
input1, input1_index,
input2, input2_index,
res, res_index)
if input_shape == ((9, 128, 128), (9, 128, 128), "float32", False, True):
with tik_instance.for_range(0, 27, block_num=27) as block_idx:
with tik_instance.for_range(0, 42, thread_num=2) as cc0:
input1_index = (block_idx // 3) * 16384 + (block_idx % 3) * 5504 + cc0 * 128
input2_index = (block_idx // 3) * 16384
res_index = (block_idx // 3) * 16384 + (block_idx % 3) * 5504 + cc0 * 128
_inner_matmul_new(tik_instance, dtype,
input1, input1_index,
input2, input2_index,
res, res_index)
with tik_instance.if_scope((block_idx % 3) < 2):
input1_index = (block_idx // 3) * 16384 + (block_idx % 3) * 5504 + 42 * 128
input2_index = (block_idx // 3) * 16384
res_index = (block_idx // 3) * 16384 + (block_idx % 3) * 5504 + 42 * 128
_inner_matmul_new(tik_instance, dtype,
input1, input1_index,
input2, input2_index,
res, res_index)
if input_shape == ((1, 64, 64), (1, 64, 64), "float32", False, True):
with tik_instance.for_range(0, 32, block_num=32) as block_idx:
with tik_instance.for_range(0, 2, thread_num=2) as cc0:
input1_index = block_idx * 128 + cc0 * 64
input2_index = 0
res_index = block_idx * 128 + cc0 * 64
_inner_matmul_new_1_64_32_64(tik_instance, dtype,
input1, input1_index,
input2, input2_index,
res, res_index)
input_shape_list = [((1, 128, 128), (1, 128, 128), "float32", False, True),
((2, 128, 128), (2, 128, 128), "float32", False, True),
((4, 128, 128), (4, 128, 128), "float32", False, True),
((8, 128, 128), (8, 128, 128), "float32", False, True),
((16, 128, 128), (16, 128, 128), "float32", False, True)
]
if input_shape in input_shape_list:
block_num = 32
input1_unit_size = 128
input2_unint_size = 128 * 128
with tik_instance.for_range(0, block_num, block_num=block_num) as block_idx:
block_process_ele_num = (batch * m * k) // block_num
loop_time = (batch * m * k) // block_num // input1_unit_size
thread_num = 2
with tik_instance.for_range(0, loop_time, thread_num=thread_num) as cc0:
input1_index = block_idx * block_process_ele_num + cc0 * input1_unit_size
if batch > 1:
input2_index = block_idx // (block_num // batch) * input2_unint_size
else:
input2_index = 0
res_index = block_idx * block_process_ele_num + cc0 * input1_unit_size
_inner_matmul_new(tik_instance, dtype,
input1, input1_index,
input2, input2_index,
res, res_index)
tik_instance.BuildCCE(kernel_name, inputs=[input1, input2], outputs=[res])
return tik_instance
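Ignoring the tiling and UB buffering above, the computation CusBatchMatMul performs for its supported shapes (transpose_a=False, transpose_b=True) is an ordinary batched matmul against the transposed second operand. A NumPy reference sketch that can be used to sanity-check the kernel output, assuming float32 inputs of one of the shapes in support_shape:

import numpy as np

def batch_matmul_reference(x1, x2):
    """NumPy reference for transpose_a=False, transpose_b=True:
    y[b] = x1[b] @ x2[b].T for each batch b."""
    return np.einsum('bik,bjk->bij', x1, x2)

# Example check against one supported shape:
x1 = np.random.rand(8, 128, 128).astype(np.float32)
x2 = np.random.rand(8, 128, 128).astype(np.float32)
y = batch_matmul_reference(x1, x2)   # (8, 128, 128)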

View File

@@ -0,0 +1,111 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CusCholeskyTrsm"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
from te import tik
from topi.cce import util
cus_cholesky_trsm_op_info = TBERegOp("CusCholeskyTrsm") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("choleskytrsm.so") \
.compute_cost(10) \
.kernel_name("CusCholeskyTrsm") \
.partial_flag(True) \
.input(0, "x1", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(cus_cholesky_trsm_op_info)
def CusCholeskyTrsm(input_x, output, kernel_name):
"""CusCholeskyTrsm"""
input_x_shape = input_x.get("shape")
output_shape = output.get("shape")
split_dim = 128
matrix_dim = input_x_shape[0]
split_dim = min(matrix_dim, split_dim)
vector_repeat_times = int(split_dim // 64)
blocks = int(matrix_dim // split_dim)
if blocks == 0:
blocks = 1
if util.get_product_version() == util.VERSION_MINI:
tik_instance = tik.Tik(tik.Dprofile("v100", "mini"))
else:
tik_instance = tik.Tik(tik.Dprofile("v100", "cloud"))
input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm)
res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm)
with tik_instance.for_range(0, blocks, block_num=blocks) as block_index:
input_x_ub = tik_instance.Tensor("float32", (split_dim, split_dim), name="input_x_ub", scope=tik.scope_ubuf)
temp_ub = tik_instance.Tensor("float32", (split_dim, split_dim), name="temp_ub", scope=tik.scope_ubuf)
assist_1_ub = tik_instance.Tensor("float32", (split_dim,), name="assist_1_ub", scope=tik.scope_ubuf)
assist_2_ub = tik_instance.Tensor("float32", (split_dim,), name="assist_2_ub", scope=tik.scope_ubuf)
with tik_instance.for_range(0, split_dim) as i:
tik_instance.data_move(input_x_ub[i, 0], input_x[block_index * split_dim + i, block_index * split_dim], 0,
1, vector_repeat_times * 8, 0, 0)
scalar1 = tik_instance.Scalar("float32", init_value=-0.5)
with tik_instance.for_range(0, split_dim) as i:
scalar2 = tik_instance.Scalar("float32")
tik_instance.vln(64, assist_1_ub[0], input_x_ub[i, 0], vector_repeat_times, 1, 1, 8, 8)
tik_instance.vmuls(64, assist_2_ub[0], assist_1_ub[0], scalar1, vector_repeat_times, 1, 1, 8, 8)
tik_instance.vexp(64, assist_1_ub[0], assist_2_ub[0], vector_repeat_times, 1, 1, 8, 8)
scalar2.set_as(assist_1_ub[i])
tik_instance.vmuls(64, input_x_ub[i, 0], input_x_ub[i, 0], scalar2, vector_repeat_times, 1, 1, 8, 8)
with tik_instance.for_range(i + 1, split_dim) as j:
scalar3 = tik_instance.Scalar("float32")
scalar3.set_as(input_x_ub[i, j])
tik_instance.vmuls(64, temp_ub[j, 0], input_x_ub[i, 0], scalar3, vector_repeat_times, 1, 1, 8, 8)
tik_instance.vsub(64, input_x_ub[i + 1, 0], input_x_ub[i + 1, 0], temp_ub[i + 1, 0],
(split_dim - 1 - i) * vector_repeat_times, 1, 1, 1, 8, 8, 8)
zero = tik_instance.Scalar("float32")
zero.set_as(0.0)
one = tik_instance.Scalar("float32")
one.set_as(1.0)
with tik_instance.for_range(0, split_dim) as i:
tik_instance.vector_dup(64, temp_ub[i, 0], zero, vector_repeat_times, 1, 8)
temp_ub.__setitem__(i * split_dim + i, one)
chol_diag_element_final = tik_instance.Scalar("float32")
chol_diag_element_final.set_as(input_x_ub[split_dim * split_dim - 1])
trsm_diag_element = tik_instance.Scalar("float32")
trsm_diag_element.set_as(1.0 / chol_diag_element_final)
temp_ub.__setitem__(split_dim * split_dim - 1, trsm_diag_element)
with tik_instance.for_range(1, split_dim) as i:
index = split_dim - i - 1
tik_instance.vector_dup(64, assist_1_ub, zero, vector_repeat_times, 1, 8)
with tik_instance.for_range(0, i) as j:
chol_diag_element_loop = tik_instance.Scalar("float32")
chol_diag_element_loop.set_as(input_x_ub[index, index + 1 + j])
tik_instance.vmuls(64, assist_2_ub, temp_ub[j + index + 1, 0], chol_diag_element_loop,
vector_repeat_times, 1, 1, 8, 8)
tik_instance.vadd(64, assist_1_ub, assist_2_ub, assist_1_ub, vector_repeat_times, 1, 1, 1, 8, 8, 8)
temp_scalar = tik_instance.Scalar("float32")
temp_scalar.set_as(input_x_ub[index, index])
chol_diag_element = tik_instance.Scalar("float32")
chol_diag_element.set_as(1.0 / temp_scalar)
tik_instance.vsub(64, temp_ub[index, 0], temp_ub[index, 0], assist_1_ub, vector_repeat_times, 1, 1, 1, 8, 8,
8)
tik_instance.vmuls(64, temp_ub[index, 0], temp_ub[index, 0], chol_diag_element, vector_repeat_times, 1, 1,
8, 8)
tik_instance.data_move(res[block_index, 0, 0], temp_ub, 0, 1, 8 * vector_repeat_times * split_dim, 0, 0)
tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x], outputs=[res])
return tik_instance
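Reading the kernel above, each 128x128 (split_dim x split_dim) diagonal block is Cholesky-factored in place and then back-substituted against the identity, so the value written out per block is roughly the inverse of the lower Cholesky factor. A NumPy sketch of that per-block reference, assuming the block is symmetric positive definite:

import numpy as np

def cholesky_trsm_reference(block):
    """Per-block sketch: factor block = L @ L.T, then return L^-1
    (triangular solve of L against the identity)."""
    l = np.linalg.cholesky(block)
    return np.linalg.solve(l, np.eye(l.shape[0], dtype=block.dtype))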

File diff suppressed because it is too large.

File diff suppressed because it is too large.

View File

@@ -0,0 +1,468 @@
# -*- coding:utf-8 -*-
"""
Copyright 2020 Huawei Technologies Co., Ltd
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
matmul
"""
from __future__ import absolute_import
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
import te.lang.cce
import te.platform.cce_params as cce
from te import tik
from te import tvm
from topi import generic
from topi.cce import util
# General limitation of the size for input shape: 2**31
SHAPE_SIZE_LIMIT = 2147483648
NoneType = type(None)
matmul_cube_dense_left_op_info = TBERegOp("CusMatMulCubeDenseLeft") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("matmulcubedenseleft.so") \
.compute_cost(10) \
.kernel_name("CusMatMulCubeDenseLeft") \
.partial_flag(True) \
.input(0, "x1", False, "required", "all") \
.input(1, "x2", False, "required", "all") \
.input(2, "x3", False, "optional", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_FracNZ, DataType.F16_Default, DataType.F16_FracNZ) \
.get_op_info()
# pylint: disable=locally-disabled,too-many-arguments,too-many-branches, too-many-statements, too-many-locals,
def _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b):
"""
Check the given input if legal
Parameters:
shape_a: list or tuple
Shape of the first tensor a with rank > 1
shape_b: list or tuple
Shape of the second tensor b with the same type with a,
and shape_a, shape_b must be 2 dims
shape_bias: list or tuple
Shape of bias, only support the input data format with ND
src_dtype: str
The data type of input, support "float32", "float16"
trans_a: bool
If True, shape_a is transposed before multiplication
trans_b: bool
If True, shape_b is transposed before multiplication
Returns None
"""
shape_len = len(shape_a)
src_dtype = src_dtype.lower()
k_block_size = cce.BLOCK_REDUCE
check_list = ("float16")
if src_dtype not in check_list:
raise RuntimeError("matmul_cce only support %s while src_dtype == %s"
% (",".join(check_list), src_dtype))
if shape_len != len(shape_b):
raise RuntimeError("length of a and b are not equal")
if shape_len != 2:
raise RuntimeError(
"length of shape must be 2, more than 2 dimensions should use batch_matmul now!")
is_gevm = True if shape_a[-2] == 1 or shape_a[-1] == 1 else False
is_gemv = True if shape_b[-2] == 1 or shape_b[-1] == 1 else False
if trans_a:
m_shape = shape_a[shape_len - 1]
km_shape = shape_a[shape_len - 2]
else:
m_shape = shape_a[shape_len - 2]
km_shape = shape_a[shape_len - 1]
if trans_b:
kn_shape = shape_b[shape_len - 1]
n_shape = shape_b[shape_len - 2]
else:
kn_shape = shape_b[shape_len - 2]
n_shape = shape_b[shape_len - 1]
if m_shape == 1:
if n_shape == 1:
raise RuntimeError("input shape M and N can't both be 1")
if km_shape != kn_shape:
print(km_shape, kn_shape)
raise RuntimeError("reduce axis not same")
if m_shape % cce.BLOCK_IN != 0 and m_shape != 1:
raise RuntimeError(
"input shape M should be 1 or multiple of %d" % cce.BLOCK_IN)
if m_shape != 1:
if n_shape == 1:
if km_shape % (cce.BLOCK_IN * cce.BLOCK_IN) != 0:
raise RuntimeError("input shape K1 should be multiple of %d"
% (cce.BLOCK_IN * cce.BLOCK_IN))
elif km_shape % k_block_size != 0:
raise RuntimeError(
"input shape K1 should be multiple of %d" % cce.BLOCK_IN)
else:
if km_shape % (cce.BLOCK_IN * cce.BLOCK_IN) != 0:
raise RuntimeError("input shape K1 should be multiple of %d"
% (cce.BLOCK_IN * cce.BLOCK_IN))
if n_shape % cce.BLOCK_IN != 0 and n_shape != 1:
raise RuntimeError("input shape N should be 1 or multiple of %d" % cce.BLOCK_IN)
if len(shape_bias) != 0:
if len(shape_bias) == 1:
if is_gevm or is_gemv:
if shape_bias[0] != m_shape * n_shape:
raise RuntimeError("broadcast case shape bias for gemv must be equal m*n")
else:
if shape_bias[0] != n_shape:
raise RuntimeError("broadcast bias shape must be equal to shape n")
elif len(shape_bias) == shape_len:
if [i for i in shape_bias[-2:]] != [m_shape, n_shape]:
raise RuntimeError("non broadcast bias shape must be same as output shape")
else:
raise RuntimeError("unsupport input shape now for batch bias case")
def _get_bias(shape_bias):
"""_get_bias"""
bias_length = shape_bias[0]
shb = []
if bias_length % 16 == 0:
shb = shape_bias
else:
bias_length = (bias_length // 16) * 16 + 16
shape_bias = []
shape_bias.append(bias_length)
shb = shape_bias
return shb
def _get_input_shape(shape_x):
"""_get_input_shape"""
dim_a = shape_x[0]
dim_b = shape_x[1]
res = []
if dim_a % 16 != 0:
dim_a = (dim_a // 16) * 16 + 16
res.append(dim_a)
else:
res.append(dim_a)
if dim_b % 16 != 0:
dim_b = (dim_b // 16) * 16 + 16
res.append(dim_b)
else:
res.append(dim_b)
return res
def check_supported(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, kernel_name="matmulcube"):
"""check_supported"""
shape_a = input_x1.get("shape")
shape_b = input_x2.get("shape")
print("shape_a: ", shape_a)
print("shape_b: ", shape_b)
src_dtype = input_x1.get("dtype")
util.check_kernel_name(kernel_name)
util.check_shape_rule(shape_a)
util.check_shape_rule(shape_b)
util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT)
util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT)
try:
trans_a_f = bool(1 - trans_a)
if src_dtype == "float32" or src_dtype == "int32":
if len(shape_a) != 2 and len(shape_b) != 2:
return False
if trans_b:
if shape_b[0] == 1:
return False
else:
if shape_b[1] == 1:
return False
if trans_a:
if trans_b:
if shape_a[0] != shape_b[1]:
return False
elif shape_a[0] != shape_b[0]:
return False
elif trans_b:
if shape_a[1] != shape_b[1]:
return False
elif shape_a[1] != shape_b[0]:
return False
if trans_a_f and trans_b and shape_b[1] == 1:
return False
if src_dtype == "float16":
if len(shape_a) != 2 and len(shape_b) != 2:
return False
if trans_a:
m_shape = shape_a[1]
k_shape = shape_a[0]
else:
m_shape = shape_a[0]
k_shape = shape_a[1]
if trans_b:
n_shape = shape_b[0]
k_b_shape = shape_b[1]
else:
n_shape = shape_b[1]
k_b_shape = shape_b[0]
if k_shape != k_b_shape:
return False
if m_shape == 1 or n_shape == 1:
if k_shape % 256 != 0:
return False
except RuntimeError as e:
return False
return True
# pylint: disable=locally-disabled,too-many-arguments, too-many-locals, too-many-statements
# @util.check_input_type(dict, dict, (dict, NoneType), dict, bool, bool, str)
@op_info_register(matmul_cube_dense_left_op_info)
def CusMatMulCubeDenseLeft(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False,
kernel_name="matmulcube"):
"""
calculating matrix multiplication with bias, C = A*B + bias, support input
data with fractal format.
Parameters:
shape_a: list or tuple
Shape of the first tensor a with rank > 1
shape_b: list or tuple
Shape of the second tensor b with the same type with a,
and shape_a, shape_b must be 2 dims
src_dtype: str
The data type of input, support "float32", "float16"
dst_dtype: str
The data type of output, support "float32", "float16"
trans_a: bool
If True, shape_a is transposed before multiplication
trans_b: bool
If True, shape_b is transposed before multiplication
is_fractal: bool
If True, the input data format of a and b must be fractal format
shape_bias: list or tuple
Shape of bias, only support the input data format with ND
Returns
-------
None
"""
print("!!!!come into zzt~~~~~~~!!!!")
shape_a = input_x1.get("ori_shape")
shape_b = input_x2.get("ori_shape")
shape_output = output_y.get("ori_shape")
print("============")
print(input_x1.get("format"), input_x2.get("format"))
print(shape_a, shape_b)
print("============")
if input_x2.get("format") == "FRACTAL_Z":
n, c, h, w = shape_b
c0 = 16
c1 = c // c0
if c1 == 0:
c1 = 1
shape_b = [n, c1 * h * w * c0]
shape_a = [n, n]
if input_x1.get("format") == "FRACTAL_Z":
n, c, h, w = shape_a
c0 = 16
c1 = c // c0
if c1 == 0:
c1 = 1
shape_a = [n, c1 * h * w * c0]
shape_b = [c1 * h * w * c0, c1 * h * w * c0]
if input_x2.get("format") == "FRACTAL_NZ":
shape_a = [shape_b[0], shape_b[0]]
shape_b = shape_b
if input_x1.get("format") == "FRACTAL_NZ":
shape_a = shape_a
shape_b = [shape_a[1], shape_a[1]]
shape_a = list(shape_a)
shape_b = list(shape_b)
shape_a = _get_input_shape(shape_a)
shape_b = _get_input_shape(shape_b)
util.check_kernel_name(kernel_name)
util.check_shape_rule(shape_a)
util.check_shape_rule(shape_b)
util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT)
util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT)
shape_a = [shape_a[1], shape_a[0]]
trans_a = bool(1 - trans_a)
shape_b = [shape_b[1], shape_b[0]]
trans_b = bool(1 - trans_b)
shape_bias = ()
if bias is not None and bool(bias):
shape_bias = bias.get("shape")
shape_bias = list(shape_bias)
shape_bias = _get_bias(shape_bias)
src_dtype = input_x1.get("dtype").lower()
dst_dtype = output_y.get("dtype").lower()
_shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b)
m_shape = shape_a[len(shape_a) - 2]
km_shape = shape_a[len(shape_a) - 1]
kn_shape = shape_b[len(shape_a) - 2]
n_shape = shape_b[len(shape_a) - 1]
if src_dtype == "float16":
block_reduce = cce.BLOCK_REDUCE
block_in = cce.BLOCK_IN
block_out = cce.BLOCK_OUT
if trans_a and km_shape == 1:
block_in = cce.BLOCK_VECTOR
if not trans_a and m_shape == 1:
block_in = cce.BLOCK_VECTOR
if trans_b and kn_shape == 1:
block_out = cce.BLOCK_VECTOR
if not trans_b and n_shape == 1:
block_out = cce.BLOCK_VECTOR
if trans_a:
shape_a_temp = (m_shape // block_reduce, km_shape // block_in, block_reduce, block_in)
else:
shape_a_temp = (m_shape // block_in, km_shape // block_reduce, block_in, block_reduce)
if trans_b:
shape_b_temp = (kn_shape // block_out, n_shape // block_reduce, block_reduce, block_out)
else:
shape_b_temp = (kn_shape // block_reduce, n_shape // block_out, block_out, block_reduce)
shape_a_temp = (shape_a_temp[0], shape_a_temp[1], shape_a_temp[2], shape_a_temp[3])
format_a = "FRACTAL_NZ"
shape_b_temp = (shape_b_temp[0], shape_b_temp[1], shape_b_temp[2], shape_b_temp[3])
format_b = "FRACTAL_NZ"
print("=======================================")
print(shape_a_temp, shape_b_temp)
print(format_a, format_b)
print("=======================================")
tensor_bias = None
tensor_a = tvm.placeholder(shape_a_temp, name='tensor_a',
dtype=src_dtype)
tensor_b = tvm.placeholder(shape_b_temp, name='tensor_b',
dtype=src_dtype)
if len(shape_bias) > 0:
tensor_bias = tvm.placeholder(shape_bias, name='tensor_bias',
dtype=dst_dtype)
if shape_a_temp[0] == 63 and shape_a_temp[1] == 63 and shape_b_temp[0] == 128 and shape_b_temp[1] == 63:
if util.get_product_version() == util.VERSION_MINI:
tik_instance = tik.Tik(tik.Dprofile("v100", "mini"))
else:
tik_instance = tik.Tik(tik.Dprofile("v100", "cloud"))
input_x1 = tik_instance.Tensor("float16", shape_a_temp, name="left_matrix", scope=tik.scope_gm)
input_x2 = tik_instance.Tensor("float16", shape_b_temp, name="right_matrix", scope=tik.scope_gm)
resMatmul = tik_instance.Tensor("float16", shape_output, name="output", scope=tik.scope_gm)
with tik_instance.for_range(0, 32, block_num=32) as block_index:
resMatmul_local_UB = tik_instance.Tensor("float16", (128 * 256,), scope=tik.scope_ubuf,
name="resMatmul_local_UB")
resMatmul_local_UB_local_L0C = tik_instance.Tensor("float32", (128 * 256,), scope=tik.scope_cc,
name="resMatmul_local_UB")
input_1_local_L1_local_L0A = tik_instance.Tensor("float16", (128 * 128,), scope=tik.scope_ca,
name="input_1_local_L1_local_L0A")
input_2_local_L1 = tik_instance.Tensor("float16", (128 * 256,), scope=tik.scope_cbuf,
name="input_2_local_L1")
input_1_local_L1 = tik_instance.Tensor("float16", (128 * 128,), scope=tik.scope_cbuf,
name="input_1_local_L1")
input_2_local_L1_local_L0B = tik_instance.Tensor("float16", (128 * 256,), scope=tik.scope_cb,
name="input_2_local_L1_local_L0B")
core_m_idx = block_index % 8
core_n_idx = block_index // 8
with tik_instance.if_scope(core_m_idx != 7):
tik_instance.data_move(input_1_local_L1, input_x1[core_m_idx * (8 * 256 + 128 * 1008)], 0, 8, 128,
55 * 16, 0)
tik_instance.data_move(input_2_local_L1, input_x2[core_m_idx * 8 * 256 + core_n_idx * 512 * 1008], 0,
32, 128, 55 * 16, 0)
with tik_instance.for_range(0, 8) as cc12:
tik_instance.load2dv1(input_1_local_L1_local_L0A[cc12 * 2048], input_1_local_L1[cc12 * 256], 0, 8,
8, 0, False)
with tik_instance.for_range(0, 2) as cc6:
with tik_instance.for_range(0, 8) as cc121:
tik_instance.load2dv1(input_2_local_L1_local_L0B[cc121 * 4096],
input_2_local_L1[cc6 * 32768 + cc121 * 256], 0, 16, 8, 0, True)
tik_instance.mmad(resMatmul_local_UB_local_L0C, input_1_local_L1_local_L0A,
input_2_local_L1_local_L0B, 128, 128, 256, 0)
tik_instance.data_move(resMatmul_local_UB, resMatmul_local_UB_local_L0C, 0, 1, 128, 0, 0, 1)
tik_instance.data_move(resMatmul[cc6 * 256 * 1008 + core_m_idx * 8 * 256 + core_n_idx * 512 * 1008],
resMatmul_local_UB, 0, 16, 256 // 2, 0, 55 * 16 * 2 // 2)
with tik_instance.else_scope():
tik_instance.data_move(input_1_local_L1, input_x1[core_m_idx * (8 * 256 + 128 * 1008)], 0, 7, 112,
56 * 16, 0)
tik_instance.data_move(input_2_local_L1, input_x2[core_m_idx * 8 * 256 + core_n_idx * 512 * 1008], 0,
32, 112, 56 * 16, 0)
with tik_instance.for_range(0, 7) as cc10:
tik_instance.load2dv1(input_1_local_L1_local_L0A[cc10 * 1792], input_1_local_L1[cc10 * 256], 0, 7,
7, 0, False)
with tik_instance.for_range(0, 2) as cc5:
with tik_instance.for_range(0, 7) as cc101:
tik_instance.load2dv1(input_2_local_L1_local_L0B[cc101 * 4096],
input_2_local_L1[cc5 * 28672 + cc101 * 256], 0, 16, 7, 0, True)
tik_instance.mmad(resMatmul_local_UB_local_L0C, input_1_local_L1_local_L0A,
input_2_local_L1_local_L0B, 112, 112, 256, 0)
tik_instance.data_move(resMatmul_local_UB, resMatmul_local_UB_local_L0C, 0, 1, 112, 0, 0, 1)
tik_instance.data_move(resMatmul[cc5 * 256 * 1008 + core_m_idx * 8 * 256 + core_n_idx * 512 * 1008],
resMatmul_local_UB, 0, 16, 224 // 2, 0, 56 * 16 * 2 // 2)
tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x1, input_x2], outputs=[resMatmul])
return tik_instance
else:
print("come into tbe, shape is error!")
result = te.lang.cce.matmul(tensor_a, tensor_b, trans_a, trans_b, format_a=format_a,
format_b=format_b, dst_dtype=dst_dtype, tensor_bias=tensor_bias)
with tvm.target.cce():
schedule = generic.auto_schedule(result)
tensor_list = [tensor_a, tensor_b, result]
if len(shape_bias) > 0:
tensor_list = [tensor_a, tensor_b, tensor_bias, result]
config = {"print_ir": False,
"name": kernel_name,
"tensor_list": tensor_list}
te.lang.cce.cce_build_code(schedule, config)
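_get_bias and _get_input_shape above both round dimensions up to the next multiple of 16 so the padded matrices can be tiled into the 16x16 cube blocks that the FRACTAL_NZ shapes are built from. A compact restatement of that padding rule (the comment about the classifier width is an inference, not something stated in this file):

def round_up_to_16(dim):
    """Pad a dimension up to the next multiple of 16 (the cube block size),
    mirroring _get_bias / _get_input_shape above."""
    return dim if dim % 16 == 0 else (dim // 16) * 16 + 16

# e.g. a 1000- or 1001-wide classifier dimension pads to 1008 = 63 * 16,
# which is presumably where the 63 in the special-cased shape above comes from
assert round_up_to_16(1001) == 1008
assert round_up_to_16(2048) == 2048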

View File

@@ -0,0 +1,172 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
Copyright 2020 Huawei Technologies Co., Ltd
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
matmul
"""
from __future__ import absolute_import
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
from te import tik
from topi.cce import util
matmul_cube_dense_right_op_info = TBERegOp("CusMatMulCubeDenseRight") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("matmulcubedenseright.so") \
.compute_cost(10) \
.kernel_name("CusMatMulCubeDenseRight") \
.partial_flag(True) \
.input(0, "x1", False, "required", "all") \
.input(1, "x2", False, "required", "all") \
.input(2, "x3", False, "required", "all") \
.input(3, "x4", False, "optional", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_FracNZ, DataType.F16_Default, DataType.F32_Default, DataType.F16_Default,
DataType.F32_FracNZ) \
.get_op_info()
@op_info_register(matmul_cube_dense_right_op_info)
def CusMatMulCubeDenseRight(input_x1, input_x2, input_x3, bias=None, output_y={}, trans_a=False, trans_b=False,
kernel_name="matmulcube"):
"""CusMatMulCubeDenseRight"""
shape_a_temp = (128, 63, 16, 16)
shape_b_temp = (128, 128, 16, 16)
shape_output = output_y.get("shape")
matrix_max_shape = (1,)
support_shape = [(shape_a_temp, shape_b_temp, matrix_max_shape),]
shape_a_input = input_x1.get("shape")
shape_b_input = input_x2.get("shape")
matrix_max_input = input_x3.get("shape")
input_shape = (tuple(shape_a_input), tuple(shape_b_input), tuple(matrix_max_input))
if input_shape not in support_shape:
raise RuntimeError("input_shape %s is not supported" % str(input_shape))
if shape_a_temp[0] == 128 and shape_a_temp[1] == 63 and shape_b_temp[0] == 128 and shape_b_temp[1] == 128:
if util.get_product_version() == util.VERSION_MINI:
tik_instance = tik.Tik(tik.Dprofile("v100", "mini"))
else:
tik_instance = tik.Tik(tik.Dprofile("v100", "cloud"))
input_x1 = tik_instance.Tensor("float16", shape_a_temp, name="left_matrix", scope=tik.scope_gm)
input_x2 = tik_instance.Tensor("float16", shape_b_temp, name="right_matrix", scope=tik.scope_gm)
input_x3 = tik_instance.Tensor("float32", [1,], name="matrix_max", scope=tik.scope_gm)
resMatmul = tik_instance.Tensor("float32", shape_output, name="output", scope=tik.scope_gm)
with tik_instance.for_range(0, 32, block_num=32) as block_index:
core_m_idx = block_index // 16
core_n_idx = block_index % 16
matrix_max_scalar = tik_instance.Scalar("float32")
matrix_max_local_UB = tik_instance.Tensor("float32", (8,), scope=tik.scope_ubuf, name="matrix_max_local_UB")
tik_instance.data_move(matrix_max_local_UB, input_x3, 0, 1, 1, 0, 0)
matrix_max_scalar.set_as(matrix_max_local_UB[0])
resMatmul_local_UB = tik_instance.Tensor("float32", (256 * 128,), scope=tik.scope_ubuf,
name="resMatmul_local_UB")
resMatmul_local_UB1 = tik_instance.Tensor("float32", (240 * 128,), scope=tik.scope_ubuf,
name="resMatmul_local_UB1")
resMatmul_local_UB_local_L0C = tik_instance.Tensor("float32", (256 * 128,), scope=tik.scope_cc,
name="resMatmul_local_UB_local_L0C")
resMatmul_local_UB_local_L0C1 = tik_instance.Tensor("float32", (240 * 128,), scope=tik.scope_cc,
name="resMatmul_local_UB_local_L0C1")
input_1_local_L1_local_L0A = tik_instance.Tensor("float16", (256 * 128,), scope=tik.scope_ca,
name="input_1_local_L1_local_L0A")
input_2_local_L1 = tik_instance.Tensor("float16", (8 * 128 * 16,), scope=tik.scope_cbuf,
name="input_2_local_L1")
input_2_local_L11 = tik_instance.Tensor("float16", (8 * 128 * 16,), scope=tik.scope_cbuf,
name="input_2_local_L11")
input_1_local_L1 = tik_instance.Tensor("float16", (8 * 256 * 16,), scope=tik.scope_cbuf,
name="input_1_local_L1")
input_1_local_L11 = tik_instance.Tensor("float16", (8 * 240 * 16,), scope=tik.scope_cbuf,
name="input_1_local_L11")
input_2_local_L1_local_L0B = tik_instance.Tensor("float16", (128 * 128,), scope=tik.scope_cb,
name="input_2_local_L1_local_L0B")
input_2_local_L1_local_L0B1 = tik_instance.Tensor("float16", (128 * 128,), scope=tik.scope_cb,
name="input_2_local_L1_local_L0B1")
with tik_instance.if_scope(core_m_idx == 0):
with tik_instance.for_range(0, 2) as cc1:
tik_instance.data_move(input_2_local_L1, input_x2[core_n_idx * 262144 + core_n_idx * 2048], 0, 8,
128, 1920, 0)
tik_instance.data_move(input_1_local_L1, input_x1[core_n_idx * 129024 + cc1 * 4096], 0, 8, 256, 752,
0)
with tik_instance.for_range(0, 8) as cc10:
tik_instance.load2dv1(input_2_local_L1_local_L0B[cc10 * 2048], input_2_local_L1[cc10 * 256], 0,
8, 8, 0, True)
with tik_instance.for_range(0, 16) as cc101:
tik_instance.load2dv1(input_1_local_L1_local_L0A[cc101 * 2048], input_1_local_L1[cc101 * 256],
0, 8, 16, 0, False)
tik_instance.mmad(resMatmul_local_UB_local_L0C, input_1_local_L1_local_L0A,
input_2_local_L1_local_L0B, 256, 128, 128, 0)
tik_instance.data_move(resMatmul_local_UB, resMatmul_local_UB_local_L0C, 0, 1, 128, 0, 0)
tik_instance.vmuls(64, resMatmul_local_UB, resMatmul_local_UB, matrix_max_scalar, 255, 1, 1, 8, 8)
tik_instance.vmuls(64, resMatmul_local_UB[255 * 64], resMatmul_local_UB[255 * 64],
matrix_max_scalar, 255, 1, 1, 8, 8)
tik_instance.vmuls(64, resMatmul_local_UB[510 * 64], resMatmul_local_UB[510 * 64],
matrix_max_scalar, 2, 1, 1, 8, 8)
tik_instance.data_move(resMatmul[core_n_idx * 129024 + cc1 * 4096], resMatmul_local_UB, 0, 8, 512,
0, 1504)
with tik_instance.else_scope():
tik_instance.data_move(input_2_local_L1, input_x2[core_n_idx * 262144 + core_n_idx * 2048], 0, 8, 128,
1920, 0)
tik_instance.data_move(input_1_local_L1, input_x1[core_n_idx * 129024 + 2 * 4096], 0, 8, 256, 752, 0)
with tik_instance.for_range(0, 8) as cc10:
tik_instance.load2dv1(input_2_local_L1_local_L0B[cc10 * 2048], input_2_local_L1[cc10 * 256], 0, 8,
8, 0, True)
with tik_instance.for_range(0, 16) as cc101:
tik_instance.load2dv1(input_1_local_L1_local_L0A[cc101 * 2048], input_1_local_L1[cc101 * 256], 0, 8,
16, 0, False)
tik_instance.mmad(resMatmul_local_UB_local_L0C, input_1_local_L1_local_L0A, input_2_local_L1_local_L0B,
256, 128, 128, 0)
tik_instance.data_move(resMatmul_local_UB, resMatmul_local_UB_local_L0C, 0, 1, 128, 0, 0)
tik_instance.vmuls(64, resMatmul_local_UB, resMatmul_local_UB, matrix_max_scalar, 255, 1, 1, 8, 8)
tik_instance.vmuls(64, resMatmul_local_UB[255 * 64], resMatmul_local_UB[255 * 64], matrix_max_scalar,
255, 1, 1, 8, 8)
tik_instance.vmuls(64, resMatmul_local_UB[510 * 64], resMatmul_local_UB[510 * 64], matrix_max_scalar, 2,
1, 1, 8, 8)
tik_instance.data_move(resMatmul[core_n_idx * 129024 + 2 * 4096], resMatmul_local_UB, 0, 8, 512, 0,
1504)
tik_instance.data_move(input_2_local_L11, input_x2[core_n_idx * 262144 + core_n_idx * 2048], 0, 8, 128,
1920, 0)
tik_instance.data_move(input_1_local_L11, input_x1[core_n_idx * 129024 + 12288], 0, 8, 240, 768, 0)
with tik_instance.for_range(0, 8) as cc102:
tik_instance.load2dv1(input_2_local_L1_local_L0B1[cc102 * 2048], input_2_local_L11[cc102 * 256], 0,
8, 8, 0, True)
with tik_instance.for_range(0, 16) as cc103:
tik_instance.load2dv1(input_1_local_L1_local_L0A[cc103 * 2048], input_1_local_L11[cc103 * 256], 0,
8, 15, 0, False)
tik_instance.mmad(resMatmul_local_UB_local_L0C1, input_1_local_L1_local_L0A,
input_2_local_L1_local_L0B1, 240, 128, 128, 0)
tik_instance.data_move(resMatmul_local_UB1, resMatmul_local_UB_local_L0C1, 0, 1, 120, 0, 0)
tik_instance.vmuls(64, resMatmul_local_UB1, resMatmul_local_UB1, matrix_max_scalar, 255, 1, 1, 8, 8)
tik_instance.vmuls(64, resMatmul_local_UB1[255 * 64], resMatmul_local_UB1[255 * 64], matrix_max_scalar,
225, 1, 1, 8, 8)
tik_instance.data_move(resMatmul[core_n_idx * 129024 + 12288], resMatmul_local_UB1, 0, 8, 480, 0, 1536)
tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x1, input_x2, input_x3], outputs=[resMatmul])
return tik_instance
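Stripped of the FRACTAL_NZ layout and the L0A/L0B/L0C staging, the arithmetic in CusMatMulCubeDenseRight is a matrix product of the two main operands scaled elementwise by the single scalar streamed in through x3 (matrix_max). A shape-level NumPy sketch only; the exact transpose convention is determined by the fractal layout, so treat this as an assumption rather than a drop-in reference:

import numpy as np

def matmul_dense_right_sketch(x1, x2, matrix_max):
    """Sketch: multiply the operands, then scale every element by matrix_max[0],
    mirroring the vmuls calls issued after each mmad above."""
    return (x1 @ x2) * np.float32(matrix_max[0])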

View File

@@ -0,0 +1,526 @@
# -*- coding:utf-8 -*-
"""
Copyright 2020 Huawei Technologies Co., Ltd
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
matmul
"""
from __future__ import absolute_import
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
import te.platform.cce_params as cce
from te import tik
from topi.cce import util
# General limitation of the size for input shape: 2**31
SHAPE_SIZE_LIMIT = 2147483648
NoneType = type(None)
matmul_cube_fracz_left_cast_op_info = TBERegOp("CusMatMulCubeFraczLeftCast") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("matmulcubefraczleftcast.so") \
.compute_cost(10) \
.kernel_name("CusMatMulCubeFraczLeftCast") \
.partial_flag(True) \
.input(0, "x1", False, "required", "all") \
.input(1, "x2", False, "required", "all") \
.input(2, "x3", False, "optional", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F32_FracZ, DataType.F16_Default, DataType.F16_FracZ) \
.get_op_info()
# pylint: disable=locally-disabled,too-many-arguments,too-many-branches, too-many-statements, too-many-locals,
def _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b):
"""
Check the given input if legal
Parameters:
shape_a: list or tuple
Shape of the first tensor a with rank > 1
shape_b: list or tuple
Shape of the second tensor b with the same type with a,
and shape_a, shape_b must be 2 dims
shape_bias: list or tuple
Shape of bias, only support the input data format with ND
src_dtype: str
The data type of input, support "float32", "float16"
trans_a: bool
If True, shape_a is transposed before multiplication
trans_b: bool
If True, shape_b is transposed before multiplication
Returns None
"""
shape_len = len(shape_a)
src_dtype = src_dtype.lower()
k_block_size = cce.BLOCK_REDUCE
check_list = ("float16")
if src_dtype not in check_list:
raise RuntimeError("matmul_cce only support %s while src_dtype == %s"
% (",".join(check_list), src_dtype))
if shape_len != len(shape_b):
raise RuntimeError("length of a and b are not equal")
if shape_len != 2:
raise RuntimeError(
"length of shape must be 2, more than 2 dimensions should use batch_matmul now!")
is_gevm = True if shape_a[-2] == 1 or shape_a[-1] == 1 else False
is_gemv = True if shape_b[-2] == 1 or shape_b[-1] == 1 else False
if trans_a:
m_shape = shape_a[shape_len - 1]
km_shape = shape_a[shape_len - 2]
else:
m_shape = shape_a[shape_len - 2]
km_shape = shape_a[shape_len - 1]
if trans_b:
kn_shape = shape_b[shape_len - 1]
n_shape = shape_b[shape_len - 2]
else:
kn_shape = shape_b[shape_len - 2]
n_shape = shape_b[shape_len - 1]
if m_shape == 1:
if n_shape == 1:
raise RuntimeError("input shape M and N can't both be 1")
if km_shape != kn_shape:
print(km_shape, kn_shape)
raise RuntimeError("reduce axis not same")
if m_shape % cce.BLOCK_IN != 0 and m_shape != 1:
raise RuntimeError(
"input shape M should be 1 or multiple of %d" % cce.BLOCK_IN)
if m_shape != 1:
if n_shape == 1:
if km_shape % (cce.BLOCK_IN * cce.BLOCK_IN) != 0:
raise RuntimeError("input shape K1 should be multiple of %d"
% (cce.BLOCK_IN * cce.BLOCK_IN))
elif km_shape % k_block_size != 0:
raise RuntimeError(
"input shape K1 should be multiple of %d" % cce.BLOCK_IN)
else:
if km_shape % (cce.BLOCK_IN * cce.BLOCK_IN) != 0:
raise RuntimeError("input shape K1 should be multiple of %d"
% (cce.BLOCK_IN * cce.BLOCK_IN))
if n_shape % cce.BLOCK_IN != 0 and n_shape != 1:
raise RuntimeError("input shape N should be 1 or multiple of %d" % cce.BLOCK_IN)
if len(shape_bias):
if len(shape_bias) == 1:
if is_gevm or is_gemv:
if shape_bias[0] != m_shape * n_shape:
raise RuntimeError("broadcast case shape bias for gemv must be equal m*n")
else:
if shape_bias[0] != n_shape:
raise RuntimeError("broadcast bias shape must be equal to shape n")
elif len(shape_bias) == shape_len:
if [i for i in shape_bias[-2:]] != [m_shape, n_shape]:
raise RuntimeError("non broadcast bias shape must be same as output shape")
else:
raise RuntimeError("unsupport input shape now for batch bias case")
def _get_bias(shape_bias):
"""_get_bias"""
bias_length = shape_bias[0]
if bias_length % 16 == 0:
return shape_bias
else:
bias_length = (bias_length // 16) * 16 + 16
shape_bias = []
shape_bias.append(bias_length)
return shape_bias
def _get_input_shape(shape_x):
"""_get_input_shape"""
dim_a = shape_x[0]
dim_b = shape_x[1]
res = []
if dim_a % 16 != 0:
dim_a = (dim_a // 16) * 16 + 16
res.append(dim_a)
else:
res.append(dim_a)
if dim_b % 16 != 0:
dim_b = (dim_b // 16) * 16 + 16
res.append(dim_b)
else:
res.append(dim_b)
return res
def check_supported(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, kernel_name="matmulcube"):
"""check_supported"""
shape_a = input_x1.get("shape")
shape_b = input_x2.get("shape")
print("shape_a: ", shape_a)
print("shape_b: ", shape_b)
src_dtype = input_x1.get("dtype")
util.check_kernel_name(kernel_name)
util.check_shape_rule(shape_a)
util.check_shape_rule(shape_b)
util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT)
util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT)
try:
trans_a_f = bool(1 - trans_a)
if src_dtype == "float32" or src_dtype == "int32":
if len(shape_a) != 2 and len(shape_b) != 2:
return False
if trans_b:
if shape_b[0] == 1:
return False
else:
if shape_b[1] == 1:
return False
if trans_a:
if trans_b:
if shape_a[0] != shape_b[1]:
return False
elif shape_a[0] != shape_b[0]:
return False
elif trans_b:
if shape_a[1] != shape_b[1]:
return False
elif shape_a[1] != shape_b[0]:
return False
if trans_a_f and trans_b and shape_b[1] == 1:
return False
if src_dtype == "float16":
if len(shape_a) != 2 and len(shape_b) != 2:
return False
if trans_a:
m_shape = shape_a[1]
k_shape = shape_a[0]
else:
m_shape = shape_a[0]
k_shape = shape_a[1]
if trans_b:
n_shape = shape_b[0]
k_b_shape = shape_b[1]
else:
n_shape = shape_b[1]
k_b_shape = shape_b[0]
if k_shape != k_b_shape:
return False
if m_shape == 1 or n_shape == 1:
if k_shape % 256 != 0:
return False
    except RuntimeError:
return False
return True
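# Hedged usage sketch (hypothetical tensor dicts, kept as comments so nothing runs
# at import time): check_supported receives TBE-style dicts and, for float16,
# mainly verifies that the reduce dimensions match after the optional transposes,
# e.g. shape_a[1] == shape_b[0] when neither input is transposed.
# example_x1 = {"shape": (256, 512), "dtype": "float16"}
# example_x2 = {"shape": (512, 128), "dtype": "float16"}
# check_supported(example_x1, example_x2)  # expected True under these assumptions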
# pylint: disable=locally-disabled,too-many-arguments, too-many-locals, too-many-statements
@op_info_register(matmul_cube_fracz_left_cast_op_info)
def CusMatMulCubeFraczLeftCast(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False,
kernel_name="CusMatMulCubeFraczLeftCast"):
"""
calculating matrix multiplication with bias, C = A*B + bias, support input
data with fractal format.
Parameters:
shape_a: list or tuple
Shape of the first tensor a with rank > 1
shape_b: list or tuple
Shape of the second tensor b with the same type with a,
and shape_a, shape_b must be 2 dims
src_dtype: str
The data type of input, support "float32", "float16"
dst_dtype: str
The data type of output, support "float32", "float16"
    trans_a: bool
        If True, shape_a is transposed before multiplication
    trans_b: bool
        If True, shape_b is transposed before multiplication
is_fractal: bool
If True, the input data format of a and b must be fractal format
shape_bias: list or tuple
Shape of bias, only support the input data format with ND
Returns
-------
None
"""
shape_a = input_x1.get("ori_shape")
shape_b = input_x2.get("ori_shape")
print("============")
print(input_x1.get("format"), input_x2.get("format"))
print(shape_a, shape_b)
print("============")
if input_x2.get("format") == "FRACTAL_Z":
n, c, h, w = shape_b
c0 = 16
c1 = c // c0
if c1 == 0:
c1 = 1
shape_b = [n, c1 * h * w * c0]
shape_a = [n, n]
if input_x1.get("format") == "FRACTAL_Z":
n, c, h, w = shape_a
c0 = 16
c1 = c // c0
if c1 == 0:
c1 = 1
shape_a = [n, c1 * h * w * c0]
shape_b = [c1 * h * w * c0, c1 * h * w * c0]
if input_x2.get("format") == "FRACTAL_NZ":
shape_a = [shape_b[0], shape_b[0]]
shape_b = shape_b
if input_x1.get("format") == "FRACTAL_NZ":
shape_a = shape_a
shape_b = [shape_a[1], shape_a[1]]
shape_a = list(shape_a)
shape_b = list(shape_b)
shape_a = _get_input_shape(shape_a)
shape_b = _get_input_shape(shape_b)
util.check_kernel_name(kernel_name)
util.check_shape_rule(shape_a)
util.check_shape_rule(shape_b)
util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT)
util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT)
shape_a = [shape_a[1], shape_a[0]]
trans_a = bool(1 - trans_a)
shape_b = [shape_b[1], shape_b[0]]
trans_b = bool(1 - trans_b)
shape_bias = ()
if bias is not None and bool(bias):
shape_bias = bias.get("shape")
shape_bias = list(shape_bias)
shape_bias = _get_bias(shape_bias)
src_dtype = input_x1.get("dtype").lower()
_shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b)
m_shape = shape_a[len(shape_a) - 2]
km_shape = shape_a[len(shape_a) - 1]
    kn_shape = shape_b[len(shape_b) - 2]
    n_shape = shape_b[len(shape_b) - 1]
if src_dtype == "float16":
block_reduce = cce.BLOCK_REDUCE
block_in = cce.BLOCK_IN
block_out = cce.BLOCK_OUT
if trans_a and km_shape == 1:
block_in = cce.BLOCK_VECTOR
if not trans_a and m_shape == 1:
block_in = cce.BLOCK_VECTOR
if trans_b and kn_shape == 1:
block_out = cce.BLOCK_VECTOR
if not trans_b and n_shape == 1:
block_out = cce.BLOCK_VECTOR
if trans_a:
shape_a_temp = (m_shape // block_reduce, km_shape // block_in, block_reduce, block_in)
else:
shape_a_temp = (m_shape // block_in, km_shape // block_reduce, block_in, block_reduce)
if trans_b:
shape_b_temp = (kn_shape // block_out, n_shape // block_reduce, block_reduce, block_out)
else:
shape_b_temp = (kn_shape // block_reduce, n_shape // block_out, block_out, block_reduce)
shape_a_temp = (shape_a_temp[0], shape_a_temp[1], shape_a_temp[2], shape_a_temp[3])
shape_b_temp = (shape_b_temp[0], shape_b_temp[1], shape_b_temp[2], shape_b_temp[3])
if util.get_product_version() == util.VERSION_MINI:
tik_instance = tik.Tik(tik.Dprofile("v100", "mini"))
else:
tik_instance = tik.Tik(tik.Dprofile("v100", "cloud"))
input_x1 = tik_instance.Tensor(input_x1.get("dtype"), shape_a_temp, name="left_matrix", scope=tik.scope_gm)
input_x2 = tik_instance.Tensor(input_x2.get("dtype"), shape_b_temp, name="right_matrix", scope=tik.scope_gm)
res_matmul = tik_instance.Tensor(output_y.get("dtype"), output_y.get("shape"), name="output", scope=tik.scope_gm)
DIAG_SIZE = 128
mo_tile, ko_tile, no_tile, diag_opt = get_cus_tile_info(input_x1, input_x2, DIAG_SIZE)
cus_cube_matmul_cast(tik_instance, input_x1, trans_a, input_x2, trans_b, res_matmul,
mo_tile=mo_tile, ko_tile=ko_tile, no_tile=no_tile,
diag_opt=diag_opt, diag_size=DIAG_SIZE)
tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x1, input_x2], outputs=[res_matmul])
return tik_instance
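# Illustrative sketch only: with float16 inputs block_in, block_out and
# block_reduce are all 16, so the 2-D (M, K) matrix above is viewed as a
# (M // 16, K // 16, 16, 16) fractal tensor before it reaches the cube unit.
# The assumed helper below just documents that mapping and is never called.
def _fractal_view_example(m, k, block=16):
    """Illustrative only: fractal shape of an aligned (m, k) float16 matrix."""
    return (m // block, k // block, block, block)
# _fractal_view_example(256, 512) -> (16, 32, 16, 16)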
def get_cus_tile_info(input_x1, input_x2, diag_size):
"""get_cus_tile_info"""
tile_map = {
((32, 32, 16, 16), (128, 32, 16, 16)): (8, 8, 16),
((8, 8, 16, 16), (72, 8, 16, 16)): (8, 8, 4),
((32, 32, 16, 16), (288, 32, 16, 16)): (8, 8, 12),
((128, 128, 16, 16), (32, 128, 16, 16)): (8, 8, 16),
((16, 16, 16, 16), (144, 16, 16, 16)): (8, 8, 9),
((64, 64, 16, 16), (16, 64, 16, 16)): (8, 8, 4),
((16, 16, 16, 16), (64, 16, 16, 16)): (8, 8, 4),
((32, 32, 16, 16), (8, 32, 16, 16)): (8, 8, 1),
((128, 128, 16, 16), (64, 128, 16, 16)): (8, 8, 16),
((16, 16, 16, 16), (4, 16, 16, 16)): (8, 8, 1),
((16, 16, 16, 16), (32, 16, 16, 16)): (8, 8, 2),
((64, 64, 16, 16), (32, 64, 16, 16)): (8, 8, 8),
((32, 32, 16, 16), (64, 32, 16, 16)): (8, 8, 8),
((32, 32, 16, 16), (16, 32, 16, 16)): (8, 8, 2),
((8, 8, 16, 16), (32, 8, 16, 16)): (8, 8, 1),
((8, 8, 16, 16), (16, 8, 16, 16)): (4, 8, 1),
((4, 4, 16, 16), (16, 4, 16, 16)): (2, 4, 1),
((4, 4, 16, 16), (4, 4, 16, 16)): (1, 4, 1),
((4, 4, 16, 16), (36, 4, 16, 16)): (2, 4, 3),
((4, 4, 16, 16), (49, 4, 16, 16)): (1, 4, 7)
}
shape_info = (tuple(input_x1.shape), tuple(input_x2.shape))
diag_opt = False
if input_x1.shape[0] * input_x1.shape[3] > diag_size:
diag_opt = True
if shape_info not in tile_map:
raise ValueError("shape %s is not supported" % str(shape_info))
mo_tile, ko_tile, no_tile = tile_map[shape_info]
return mo_tile, ko_tile, no_tile, diag_opt
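# Hedged worked example of the lookup above: for the fractal shape pair
# ((32, 32, 16, 16), (128, 32, 16, 16)) the table returns
# (mo_tile, ko_tile, no_tile) = (8, 8, 16), and diag_opt is enabled because
# 32 * 16 = 512 exceeds diag_size (128). The helper below restates only the
# diag_opt condition and is not used by the kernel.
def _diag_opt_example(outer_blocks, c0, diag_size=128):
    """Illustrative only: mirrors the diag_opt test in get_cus_tile_info."""
    return outer_blocks * c0 > diag_size
# _diag_opt_example(32, 16) -> True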
def cus_cube_matmul_cast(tik_instance, input_x1, trans_a, input_x2, trans_b,
res, mo_tile, ko_tile, no_tile, diag_opt=False, diag_size=128):
"""cus_cube_matmul_cast"""
ko, mo, _, _ = input_x1.shape
no, ko, _, _ = input_x2.shape
c0 = input_x1.shape[-1]
diag_outer = diag_size // c0
maxblocknum = 32
fp32_size = 4
fp16_size = 2
blocksize = 32
vectorfp32_size = 64
if [input_x1.shape[-1], input_x1.shape[-2], input_x2.shape[-1], input_x2.shape[-2]] != [c0, c0, c0, c0]:
raise ValueError("shape of input_x1 or input_x2 is not supported!")
if not trans_a or not trans_b:
raise ValueError("only trans_a=False and trans_b=False be supported!")
core_m_num = mo // mo_tile
loop_n_num = no // no_tile
if loop_n_num * core_m_num <= maxblocknum:
core_n_num = loop_n_num
else:
core_n_num = maxblocknum // core_m_num
if core_n_num > 0 and loop_n_num % core_n_num == 0:
loop_n_num = loop_n_num // core_n_num
else:
raise ValueError("Does not support this scenario!")
block_num = core_m_num * core_n_num
loop_k_num = ko // ko_tile
if diag_opt:
loop_k_num = diag_outer // ko_tile
# double buffer:
thread_num_k = 2
loop_k_num *= thread_num_k
ko_tile_inner = ko_tile // thread_num_k
with tik_instance.for_range(0, block_num, block_num=block_num) as block_idx:
core_m = block_idx // core_n_num
core_n = block_idx % core_n_num
with tik_instance.for_range(0, loop_n_num) as cc_n:
res_L0C = tik_instance.Tensor("float32", [no_tile, mo_tile, c0, c0],
name="resMatmul_L0C", scope=tik.scope_cc)
with tik_instance.for_range(0, loop_k_num, thread_num=thread_num_k) as thread_idx_k:
# input_x2 -> input_x2_ub -(fp322fp16)-> input_x2_cast_ub -> input_x2_L1
input_x2_ub = tik_instance.Tensor("float32", [no_tile, ko_tile_inner, c0, c0], name="input_x2_ub",
scope=tik.scope_ubuf)
if diag_opt:
k_idx = core_m * mo_tile + thread_idx_k * ko_tile_inner
else:
k_idx = thread_idx_k * ko_tile_inner
tik_instance.data_move(input_x2_ub,
input_x2[(core_n * loop_n_num + cc_n) * no_tile,
k_idx, 0, 0],
0, no_tile, ko_tile_inner * c0 * c0 * fp32_size // blocksize,
(ko - ko_tile_inner) * c0 * c0 * fp32_size // blocksize, 0)
input_x2_cast_ub = tik_instance.Tensor("float16", [no_tile, ko_tile_inner, c0, c0],
name="input_x2_cast_ub", scope=tik.scope_ubuf)
repeate_num = no_tile * ko_tile_inner * c0 * c0 // vectorfp32_size
repeate_times_max = 255
count = 0
while repeate_num > repeate_times_max:
tik_instance.vconv(vectorfp32_size, 'none',
input_x2_cast_ub[count * repeate_times_max * vectorfp32_size],
input_x2_ub[count * repeate_times_max * vectorfp32_size],
repeate_times_max,
1, 1, 4, 8)
repeate_num -= repeate_times_max
count += 1
tik_instance.vconv(vectorfp32_size, 'none',
input_x2_cast_ub[count * repeate_times_max * vectorfp32_size],
input_x2_ub[count * repeate_times_max * vectorfp32_size], repeate_num,
1, 1, 4, 8)
input_x2_L1 = tik_instance.Tensor("float16", [no_tile, ko_tile_inner, c0, c0],
name="input_x2_L1", scope=tik.scope_cbuf)
tik_instance.data_move(input_x2_L1, input_x2_cast_ub, 0, 1,
no_tile * ko_tile_inner * c0 * c0 * fp16_size // blocksize, 0, 0)
# input_x1 -> input_x1_L1
input_x1_L1 = tik_instance.Tensor(input_x1.dtype, [ko_tile_inner, mo_tile, c0, c0],
name="input_x1_L1", scope=tik.scope_cbuf)
tik_instance.data_move(input_x1_L1,
input_x1[k_idx,
core_m * mo_tile, 0, 0],
0, ko_tile_inner, mo_tile * c0 * c0 * fp16_size // blocksize,
(mo - mo_tile) * c0 * c0 * fp16_size // blocksize, 0)
# input_x2_L1 -> input_x2_L0B
input_x2_L0B = tik_instance.Tensor("float16", [ko_tile_inner, no_tile, c0, c0],
name="input_x2_L0B", scope=tik.scope_cb)
with tik_instance.for_range(0, ko_tile_inner) as cc2:
tik_instance.load2dv1(input_x2_L0B[cc2, 0, 0, 0], input_x2_L1[0, cc2, 0, 0], 0, no_tile,
ko_tile_inner,
0, True)
# input_x1_L1 -> input_x1_L0A
input_x1_L0A = tik_instance.Tensor(input_x1.dtype, [mo_tile, ko_tile_inner, c0, c0],
name="input_x1_L0A", scope=tik.scope_ca)
with tik_instance.for_range(0, mo_tile) as cc1:
tik_instance.load2dv1(input_x1_L0A[cc1, 0, 0, 0], input_x1_L1[0, cc1, 0, 0], 0, ko_tile_inner,
mo_tile, 0, False)
with tik_instance.if_scope(thread_idx_k == 0):
tik_instance.mmad(res_L0C, input_x1_L0A, input_x2_L0B, mo_tile * c0,
ko_tile_inner * c0, no_tile * c0, 0)
with tik_instance.else_scope():
tik_instance.mmad(res_L0C, input_x1_L0A, input_x2_L0B, mo_tile * c0,
ko_tile_inner * c0, no_tile * c0, 1)
res_ub = tik_instance.Tensor(input_x1.dtype, [no_tile, mo_tile, c0, c0],
name="resMatmul_ub", scope=tik.scope_ubuf)
tik_instance.data_move(res_ub, res_L0C, 0, 1, no_tile * mo_tile, 0, 0, 1)
tik_instance.data_move(res[(core_n * loop_n_num + cc_n) * no_tile, core_m * mo_tile, 0, 0],
res_ub, 0, no_tile,
mo_tile * c0 * c0 * fp16_size // blocksize, 0,
(mo - mo_tile) * c0 * c0 * fp16_size // blocksize)
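# Hedged note on the pipeline above: the k loop is split for double buffering,
# i.e. ko_tile is halved into ko_tile_inner while the loop count is doubled
# (thread_num=2), so one k slice can be staged while the previous one is being
# multiplied. The assumed helper below only restates that split.
def _double_buffer_split_example(ko_tile, thread_num_k=2):
    """Illustrative only: (k-loop multiplier, inner k tile) used for double buffering."""
    return thread_num_k, ko_tile // thread_num_k
# _double_buffer_split_example(8) -> (2, 4)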

View File

@ -0,0 +1,247 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
Copyright 2020 Huawei Technologies Co., Ltd
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
matmul
"""
from __future__ import absolute_import
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
from te import tik
from topi.cce import util
# General limitation of the size for input shape: 2**31
SHAPE_SIZE_LIMIT = 2147483648
NoneType = type(None)
cus_matmul_cube_fracz_right_mul_op_info = TBERegOp("CusMatMulCubeFraczRightMul") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("matmulcubefraczrightmul.so") \
.compute_cost(10) \
.kernel_name("CusMatMulCubeFraczRightMul") \
.partial_flag(True) \
.input(0, "x1", False, "required", "all") \
.input(1, "x2", False, "required", "all") \
.input(2, "x3", False, "required", "all") \
.input(3, "x4", False, "optional", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_FracZ, DataType.F16_Default, DataType.F32_Default, DataType.F16_Default,
DataType.F32_FracZ) \
.get_op_info()
@op_info_register(cus_matmul_cube_fracz_right_mul_op_info)
def CusMatMulCubeFraczRightMul(input_x1, input_x2, input_x3, bias=None, output_y={}, trans_a=False, trans_b=False,
kernel_name="matmulcube"):
"""CusMatMulCubeFraczRightMul"""
if util.get_product_version() == util.VERSION_MINI:
tik_instance = tik.Tik(tik.Dprofile("v100", "mini"))
else:
tik_instance = tik.Tik(tik.Dprofile("v100", "cloud"))
input_x1_shape = input_x1.get("shape")
input_x1_dtype = input_x1.get("dtype").lower()
input_x2_shape = input_x2.get("shape")
input_x2_dtype = input_x2.get("dtype").lower()
input_x3_shape = input_x3.get("shape")
input_x3_dtype = input_x3.get("dtype").lower()
output_shape = output_y.get("shape")
Supported = [((72, 8, 16, 16), "float16", (72, 72, 16, 16), "float16", (1,), "float32"),
((32, 8, 16, 16), "float16", (32, 32, 16, 16), "float16", (1,), "float32"),
((8, 32, 16, 16), "float16", (8, 8, 16, 16), "float16", (1,), "float32"),
((4, 4, 16, 16), "float16", (4, 4, 16, 16), "float16", (1,), "float32"),
((4, 16, 16, 16), 'float16', (4, 4, 16, 16), 'float16', (1,), 'float32'),
((49, 4, 16, 16), 'float16', (49, 49, 16, 16), 'float16', (1,), 'float32'),
((36, 4, 16, 16), 'float16', (36, 36, 16, 16), 'float16', (1,), 'float32'),
((64, 16, 16, 16), 'float16', (64, 64, 16, 16), 'float16', (1,), 'float32'),
((32, 64, 16, 16), 'float16', (32, 32, 16, 16), 'float16', (1,), 'float32'),
((32, 16, 16, 16), 'float16', (32, 32, 16, 16), 'float16', (1,), 'float32'),
((16, 32, 16, 16), 'float16', (16, 16, 16, 16), 'float16', (1,), 'float32'),
((16, 8, 16, 16), 'float16', (16, 16, 16, 16), 'float16', (1,), 'float32'),
((16, 4, 16, 16), 'float16', (16, 16, 16, 16), 'float16', (1,), 'float32'),
((288, 32, 16, 16), 'float16', (288, 288, 16, 16), 'float16', (1,), 'float32'),
((144, 16, 16, 16), 'float16', (144, 144, 16, 16), 'float16', (1,), 'float32'),
((128, 32, 16, 16), 'float16', (128, 128, 16, 16), 'float16', (1,), 'float32'),
((64, 128, 16, 16), 'float16', (64, 64, 16, 16), 'float16', (1,), 'float32'),
((32, 128, 16, 16), 'float16', (32, 32, 16, 16), 'float16', (1,), 'float32'),
((64, 32, 16, 16), 'float16', (64, 64, 16, 16), 'float16', (1,), 'float32'),
((16, 64, 16, 16), 'float16', (16, 16, 16, 16), 'float16', (1,), 'float32')]
input_shape = (
tuple(input_x1_shape), input_x1_dtype, tuple(input_x2_shape), input_x2_dtype, tuple(input_x3_shape), input_x3_dtype)
if input_shape not in Supported:
raise RuntimeError("input_shape %s is not supported" % str(input_shape))
input_x1 = tik_instance.Tensor("float16", input_x1_shape, name="left_matrix", scope=tik.scope_gm)
input_x2 = tik_instance.Tensor("float16", input_x2_shape, name="right_matrix", scope=tik.scope_gm)
input_x3 = tik_instance.Tensor("float32", input_x3_shape, name="matrix_max", scope=tik.scope_gm)
resMatmul = tik_instance.Tensor("float32", output_shape, name="output", scope=tik.scope_gm)
cus_cube_matmul_right_mul(tik_instance, input_x1, input_x2, input_x3, resMatmul)
tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x1, input_x2, input_x3], outputs=[resMatmul])
return tik_instance
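# Hedged usage sketch (hypothetical arguments, kept as comments): the op only
# accepts (shape, dtype) combinations listed in the Supported table above and
# raises RuntimeError otherwise. A call matching the first entry might look like:
# example_x1 = {"shape": (72, 8, 16, 16), "dtype": "float16"}
# example_x2 = {"shape": (72, 72, 16, 16), "dtype": "float16"}
# example_x3 = {"shape": (1,), "dtype": "float32"}
# example_out = {"shape": (72, 8, 16, 16), "dtype": "float32"}
# CusMatMulCubeFraczRightMul(example_x1, example_x2, example_x3, output_y=example_out)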
def cus_cube_matmul_right_mul(tik_instance, input_x1, input_x2, input_x3,
res):
"""cus_cube_matmul_right_mul"""
diag_size = 128
ko, mo, _, _ = input_x1.shape
no, ko, _, _ = input_x2.shape
c0 = input_x1.shape[-1]
diag_outer = diag_size // c0
if [input_x1.shape[-1], input_x1.shape[-2], input_x2.shape[-1], input_x2.shape[-2]] != [c0, c0, c0, c0]:
raise ValueError("shape of input_x1 or input_x2 is not supported!")
def get_cus_tile_info(input_x1, input_x2, input_x3):
"""get_cus_tile_info"""
input_shape = (tuple(input_x1.shape), input_x1.dtype, tuple(input_x2.shape), input_x2.dtype,
tuple(input_x3.shape), input_x3.dtype)
tile_map = {
# no diag opt:
((8, 32, 16, 16), "float16", (8, 8, 16, 16), "float16", (1,), "float32"): (4, 8, 2, 8, 4),
((4, 4, 16, 16), "float16", (4, 4, 16, 16), "float16", (1,), "float32"): (1, 4, 1, 4, 4),
((4, 16, 16, 16), 'float16', (4, 4, 16, 16), 'float16', (1,), 'float32'): (1, 4, 2, 16, 2),
((49, 4, 16, 16), 'float16', (49, 49, 16, 16), 'float16', (1,), 'float32'): (1, 7, 7, 4, 7),
((36, 4, 16, 16), 'float16', (36, 36, 16, 16), 'float16', (1,), 'float32'): (2, 6, 3, 2, 12),
# diag opt:
((288, 32, 16, 16), 'float16', (288, 288, 16, 16), 'float16', (1,), 'float32'): (16, 8, 8, 2, 12),
}
maxblocknum = 32
diag_opt = False
if input_x2.shape[0] * input_x2.shape[3] > diag_size and input_x2.shape[0] % diag_outer == 0:
diag_opt = True
if input_shape in tile_map:
mo_tile_, ko_tile_, no_tile_, core_m_num_, core_n_num_ = tile_map[input_shape]
elif diag_opt:
ko_tile_ = diag_outer
no_tile_ = ko_tile_
core_n_num_ = no // no_tile_
core_m_num_max = maxblocknum // core_n_num_
mo_tile_ = -1
core_m_num_ = -1
for i in range(core_m_num_max, 0, -1):
if mo % i == 0:
core_m_num_ = i
mo_tile_ = mo // i
break
if mo_tile_ == -1:
raise ValueError("no valid tile be found!")
while mo_tile_ > 16:
mo_tile_ = mo_tile_ // 2
else:
raise ValueError("please add tile config to the tile_map")
print("shape: %s, tile: %s" % (input_shape, str((mo_tile_, ko_tile_, no_tile_, core_m_num_, core_n_num_,
diag_opt))))
return mo_tile_, ko_tile_, no_tile_, core_m_num_, core_n_num_, diag_opt
mo_tile, ko_tile, no_tile, core_m_num, core_n_num, diag_opt = get_cus_tile_info(input_x1, input_x2, input_x3)
fp32_size = 4
fp16_size = 2
blocksize = 32
vectorfp32_size = 64
loop_n_num_total = no // no_tile
loop_m_num_total = mo // mo_tile
if loop_n_num_total % core_n_num != 0 or loop_m_num_total % core_m_num != 0:
raise ValueError("Does not support this scenario!")
loop_n_num = loop_n_num_total // core_n_num
loop_m_num = loop_m_num_total // core_m_num
block_num = core_n_num * core_m_num
loop_k_num = ko // ko_tile
if diag_opt:
loop_k_num = diag_outer // ko_tile
# double buffer:
thread_num_k = 2
if ko_tile % 2 == 0:
loop_k_num *= thread_num_k
ko_tile_inner = ko_tile // thread_num_k
else:
ko_tile_inner = ko_tile
ko_tile *= thread_num_k
with tik_instance.for_range(0, block_num, block_num=block_num) as block_idx:
core_m = block_idx // core_n_num
core_n = block_idx % core_n_num
with tik_instance.for_range(0, loop_m_num) as cc_m:
with tik_instance.for_range(0, loop_n_num) as cc_n:
res_L0C = tik_instance.Tensor("float32", [no_tile, mo_tile, c0, c0],
name="resMatmul_L0C", scope=tik.scope_cc)
with tik_instance.for_range(0, loop_k_num, thread_num=thread_num_k) as thread_idx_k:
if diag_opt:
k_idx = (core_n * loop_n_num + cc_n) * no_tile + thread_idx_k * ko_tile_inner
else:
k_idx = thread_idx_k * ko_tile_inner
# input_x1 -> input_x1_L1
input_x1_L1 = tik_instance.Tensor(input_x1.dtype, [ko_tile_inner, mo_tile, c0, c0],
name="input_x1_L1", scope=tik.scope_cbuf)
tik_instance.data_move(input_x1_L1,
input_x1[k_idx,
(core_m * loop_m_num + cc_m) * mo_tile, 0, 0],
0, ko_tile_inner, mo_tile * c0 * c0 * fp16_size // blocksize,
(mo - mo_tile) * c0 * c0 * fp16_size // blocksize, 0)
# input_x2 -> input_x2_L1
input_x2_L1 = tik_instance.Tensor("float16", [no_tile, ko_tile_inner, c0, c0],
name="input_x2_L1", scope=tik.scope_cbuf)
tik_instance.data_move(input_x2_L1,
input_x2[(core_n * loop_n_num + cc_n) * no_tile,
k_idx, 0, 0],
0, no_tile, ko_tile_inner * c0 * c0 * fp16_size // blocksize,
(ko - ko_tile_inner) * c0 * c0 * fp16_size // blocksize, 0)
# input_x1_L1 -> input_x1_L0A
input_x1_L0A = tik_instance.Tensor(input_x1.dtype, [mo_tile, ko_tile_inner, c0, c0],
name="input_x1_L0A", scope=tik.scope_ca)
with tik_instance.for_range(0, mo_tile) as cc1:
tik_instance.load2dv1(input_x1_L0A[cc1, 0, 0, 0], input_x1_L1[0, cc1, 0, 0], 0, ko_tile_inner,
mo_tile, 0, False)
# input_x2_L1 -> input_x2_L0B
input_x2_L0B = tik_instance.Tensor("float16", [ko_tile_inner, no_tile, c0, c0],
name="input_x2_L0B", scope=tik.scope_cb)
with tik_instance.for_range(0, ko_tile_inner) as cc2:
tik_instance.load2dv1(input_x2_L0B[cc2, 0, 0, 0], input_x2_L1[0, cc2, 0, 0], 0, no_tile,
ko_tile_inner,
0, True)
with tik_instance.if_scope(thread_idx_k == 0):
tik_instance.mmad(res_L0C, input_x1_L0A, input_x2_L0B, mo_tile * c0,
ko_tile_inner * c0, no_tile * c0, 0)
with tik_instance.else_scope():
tik_instance.mmad(res_L0C, input_x1_L0A, input_x2_L0B, mo_tile * c0,
ko_tile_inner * c0, no_tile * c0, 1)
res_ub = tik_instance.Tensor("float32", [no_tile, mo_tile, c0, c0],
name="resMatmul_ub", scope=tik.scope_ubuf)
tik_instance.data_move(res_ub, res_L0C, 0, 1, no_tile * mo_tile, 0, 0)
input_3_local_UB = tik_instance.Tensor("float32", (8,), scope=tik.scope_ubuf, name="input_3_local_UB")
tik_instance.data_move(input_3_local_UB, input_x3, 0, 1, 1, 0, 0)
matrix_max_scalar = tik_instance.Scalar("float32")
matrix_max_scalar.set_as(input_3_local_UB[0])
repeate_num = no_tile * mo_tile * c0 * c0 // vectorfp32_size
repeate_times_max = 255
count = 0
while repeate_num > repeate_times_max:
tik_instance.vmuls(vectorfp32_size,
res_ub[count * repeate_times_max * vectorfp32_size],
res_ub[count * repeate_times_max * vectorfp32_size],
matrix_max_scalar, repeate_times_max, 1, 1, 8, 8)
repeate_num -= repeate_times_max
count += 1
tik_instance.vmuls(vectorfp32_size,
res_ub[count * repeate_times_max * vectorfp32_size],
res_ub[count * repeate_times_max * vectorfp32_size],
matrix_max_scalar, repeate_num, 1, 1, 8, 8)
tik_instance.data_move(res[(core_n * loop_n_num + cc_n) * no_tile,
(core_m * loop_m_num + cc_m) * mo_tile, 0, 0],
res_ub, 0, no_tile,
mo_tile * c0 * c0 * fp32_size // blocksize, 0,
(mo - mo_tile) * c0 * c0 * fp32_size // blocksize)
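# Hedged sketch of the epilogue above: the single float32 value loaded from
# input_x3 (matrix_max_scalar) scales every element of the fp32 result tile via
# vmuls, so the op roughly computes res = (A @ B) * matrix_max. The pure-Python
# restatement below is illustrative only and is not used by the kernel.
def _scale_tile_example(tile_values, matrix_max):
    """Illustrative only: elementwise scaling done by the vmuls loop above."""
    return [value * matrix_max for value in tile_values]
# _scale_tile_example([1.0, 2.0, 4.0], 0.5) -> [0.5, 1.0, 2.0]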

View File

@ -0,0 +1,397 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
Copyright 2020 Huawei Technologies Co., Ltd
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
matmul
"""
from __future__ import absolute_import
from impl.matmul_vector import matmul_vector_cce
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
import te.lang.cce
import te.platform.cce_params as cce
from te import tvm
from topi import generic
from topi.cce import util
# General limitation of the size for input shape: 2**31
SHAPE_SIZE_LIMIT = 2147483648
NoneType = type(None)
matmul_cube_op_info = TBERegOp("CusMatMulCube") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("matmulcube.so") \
.compute_cost(10) \
.kernel_name("CusMatMulCube") \
.partial_flag(True) \
.attr("transpose_a", "required", "bool", "all") \
.attr("transpose_b", "required", "bool", "all") \
.input(0, "x1", False, "required", "all") \
.input(1, "x2", False, "required", "all") \
.input(2, "x3", False, "optional", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.F32_FracNZ) \
.get_op_info()
# pylint: disable=locally-disabled,too-many-arguments,too-many-branches, too-many-statements, too-many-locals,
def _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b):
"""
Check the given input if legal
Parameters:
shape_a: list or tuple
Shape of the first tensor a with rank > 1
shape_b: list or tuple
Shape of the second tensor b with the same type with a,
and shape_a, shape_b must be 2 dims
shape_bias: list or tuple
Shape of bias, only support the input data format with ND
src_dtype: str
The data type of input, support "float32", "float16"
    trans_a: bool
        If True, shape_a is transposed before multiplication
    trans_b: bool
        If True, shape_b is transposed before multiplication
Returns None
"""
shape_len = len(shape_a)
src_dtype = src_dtype.lower()
k_block_size = cce.BLOCK_REDUCE
check_list = ("float16")
if src_dtype not in check_list:
raise RuntimeError("matmul_cce only support %s while src_dtype == %s"
% (",".join(check_list), src_dtype))
if shape_len != len(shape_b):
raise RuntimeError("length of a and b are not equal")
if shape_len != 2:
raise RuntimeError(
"length of shape must be 2, more than 2 dimensions should use batch_matmul now!")
is_gevm = True if shape_a[-2] == 1 or shape_a[-1] == 1 else False
is_gemv = True if shape_b[-2] == 1 or shape_b[-1] == 1 else False
if trans_a:
m_shape = shape_a[shape_len - 1]
km_shape = shape_a[shape_len - 2]
else:
m_shape = shape_a[shape_len - 2]
km_shape = shape_a[shape_len - 1]
if trans_b:
kn_shape = shape_b[shape_len - 1]
n_shape = shape_b[shape_len - 2]
else:
kn_shape = shape_b[shape_len - 2]
n_shape = shape_b[shape_len - 1]
if m_shape == 1:
if n_shape == 1:
raise RuntimeError("input shape M and N can't both be 1")
if km_shape != kn_shape:
raise RuntimeError("reduce axis not same")
if m_shape % cce.BLOCK_IN != 0 and m_shape != 1:
raise RuntimeError(
"input shape M should be 1 or multiple of %d" % cce.BLOCK_IN)
if m_shape != 1:
if n_shape == 1:
if km_shape % (cce.BLOCK_IN * cce.BLOCK_IN) != 0:
raise RuntimeError("input shape K1 should be multiple of %d"
% (cce.BLOCK_IN * cce.BLOCK_IN))
elif km_shape % k_block_size != 0:
raise RuntimeError(
"input shape K1 should be multiple of %d" % cce.BLOCK_IN)
else:
if km_shape % (cce.BLOCK_IN * cce.BLOCK_IN) != 0:
raise RuntimeError("input shape K1 should be multiple of %d"
% (cce.BLOCK_IN * cce.BLOCK_IN))
if n_shape % cce.BLOCK_IN != 0 and n_shape != 1:
raise RuntimeError("input shape N should be 1 or multiple of %d" % cce.BLOCK_IN)
if len(shape_bias):
if len(shape_bias) == 1:
if is_gevm or is_gemv:
if shape_bias[0] != m_shape * n_shape:
raise RuntimeError("broadcast case shape bias for gemv must be equal m*n")
else:
if shape_bias[0] != n_shape:
raise RuntimeError("broadcast bias shape must be equal to shape n")
elif len(shape_bias) == shape_len:
if [i for i in shape_bias[-2:]] != [m_shape, n_shape]:
raise RuntimeError("non broadcast bias shape must be same as output shape")
else:
raise RuntimeError("unsupport input shape now for batch bias case")
def _get_bias(shape_bias):
"""_get_bias"""
bias_length = shape_bias[0]
if bias_length % 16 == 0:
return shape_bias
else:
bias_length = (bias_length // 16) * 16 + 16
shape_bias = []
shape_bias.append(bias_length)
return shape_bias
def _get_input_shape(shape_x):
"""_get_input_shape"""
dim_a = shape_x[0]
dim_b = shape_x[1]
res = []
if dim_a % 16 != 0:
dim_a = (dim_a // 16) * 16 + 16
res.append(dim_a)
else:
res.append(dim_a)
if dim_b % 16 != 0:
dim_b = (dim_b // 16) * 16 + 16
res.append(dim_b)
else:
res.append(dim_b)
return res
def check_supported(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, kernel_name="matmulcube"):
"""check_supported"""
shape_a = input_x1.get("shape")
shape_b = input_x2.get("shape")
print("shape_a: ", shape_a)
print("shape_b: ", shape_b)
src_dtype = input_x1.get("dtype")
util.check_kernel_name(kernel_name)
util.check_shape_rule(shape_a)
util.check_shape_rule(shape_b)
util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT)
util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT)
try:
trans_a_f = bool(1 - trans_a)
if src_dtype == "float32" or src_dtype == "int32":
if len(shape_a) != 2 and len(shape_b) != 2:
return False
if trans_b:
if shape_b[0] == 1:
return False
else:
if shape_b[1] == 1:
return False
if trans_a:
if trans_b:
if shape_a[0] != shape_b[1]:
return False
elif shape_a[0] != shape_b[0]:
return False
elif trans_b:
if shape_a[1] != shape_b[1]:
return False
elif shape_a[1] != shape_b[0]:
return False
if trans_a_f and trans_b and shape_b[1] == 1:
return False
if src_dtype == "float16":
if len(shape_a) != 2 and len(shape_b) != 2:
return False
if trans_a:
m_shape = shape_a[1]
k_shape = shape_a[0]
else:
m_shape = shape_a[0]
k_shape = shape_a[1]
if trans_b:
n_shape = shape_b[0]
k_b_shape = shape_b[1]
else:
n_shape = shape_b[1]
k_b_shape = shape_b[0]
if k_shape != k_b_shape:
return False
if m_shape == 1 or n_shape == 1:
if k_shape % 256 != 0:
return False
    except RuntimeError:
return False
return True
# pylint: disable=locally-disabled,too-many-arguments, too-many-locals, too-many-statements
@op_info_register(matmul_cube_op_info)
def CusMatMulCube(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, kernel_name="matmulcube"):
"""
calculating matrix multiplication with bias, C = A*B + bias, support input
data with fractal format.
Parameters:
shape_a: list or tuple
Shape of the first tensor a with rank > 1
shape_b: list or tuple
Shape of the second tensor b with the same type with a,
and shape_a, shape_b must be 2 dims
src_dtype: str
The data type of input, support "float32", "float16"
dst_dtype: str
The data type of output, support "float32", "float16"
    trans_a: bool
        If True, shape_a is transposed before multiplication
    trans_b: bool
        If True, shape_b is transposed before multiplication
is_fractal: bool
If True, the input data format of a and b must be fractal format
shape_bias: list or tuple
Shape of bias, only support the input data format with ND
Returns
-------
None
"""
shape_a = input_x1.get("ori_shape")
shape_b = input_x2.get("ori_shape")
if shape_a is not None:
if len(shape_a) < 2:
shape_a = input_x1.get("shape")
if shape_b is not None:
if len(shape_b) < 2:
shape_b = input_x2.get("shape")
shape_a = list(shape_a)
shape_b = list(shape_b)
if input_x1.get("format") == "FRACTAL_NZ":
shape_a = _get_input_shape(shape_a)
shape_b = _get_input_shape(shape_b)
util.check_kernel_name(kernel_name)
util.check_shape_rule(shape_a)
util.check_shape_rule(shape_b)
util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT)
util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT)
if input_x1.get("format") == "FRACTAL_NZ":
shape_a = [shape_a[1], shape_a[0]]
trans_a = bool(1 - trans_a)
if input_x2.get("format") == "FRACTAL_NZ":
shape_b = [shape_b[1], shape_b[0]]
trans_b = bool(1 - trans_b)
shape_bias = ()
if bias is not None and bool(bias):
shape_bias = bias.get("shape")
shape_bias = list(shape_bias)
shape_bias = _get_bias(shape_bias)
src_dtype = input_x1.get("dtype").lower()
dst_dtype = output_y.get("dtype").lower()
if src_dtype == "float32" or src_dtype == "int32":
matmul_vector_cce(shape_a, shape_b, src_dtype, trans_a, trans_b, shape_bias, kernel_name)
return
_shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b)
m_shape = shape_a[len(shape_a) - 2]
km_shape = shape_a[len(shape_a) - 1]
    kn_shape = shape_b[len(shape_b) - 2]
    n_shape = shape_b[len(shape_b) - 1]
if src_dtype == "float16":
block_reduce = cce.BLOCK_REDUCE
block_in = cce.BLOCK_IN
block_out = cce.BLOCK_OUT
if trans_a and km_shape == 1:
block_in = cce.BLOCK_VECTOR
if not trans_a and m_shape == 1:
block_in = cce.BLOCK_VECTOR
if trans_b and kn_shape == 1:
block_out = cce.BLOCK_VECTOR
if not trans_b and n_shape == 1:
block_out = cce.BLOCK_VECTOR
if trans_a:
shape_a_temp = (m_shape // block_reduce, km_shape // block_in, block_reduce, block_in)
else:
shape_a_temp = (m_shape // block_in, km_shape // block_reduce, block_in, block_reduce)
if trans_b:
shape_b_temp = (kn_shape // block_out, n_shape // block_reduce, block_reduce, block_out)
else:
shape_b_temp = (kn_shape // block_reduce, n_shape // block_out, block_out, block_reduce)
if input_x1.get("format") == "FORMAT_FRACTAL_Z":
shape_a_temp = (shape_a_temp[0], shape_a_temp[1], shape_a_temp[2], shape_a_temp[3])
format_a = "fractal"
elif input_x1.get("format") == "FRACTAL_NZ":
shape_a_temp = (shape_a_temp[0], shape_a_temp[1], shape_a_temp[2], shape_a_temp[3])
format_a = "FRACTAL_NZ"
else:
shape_a_temp = (shape_a[len(shape_a) - 2], shape_a[len(shape_a) - 1])
format_a = "ND"
if input_x2.get("format") == "FORMAT_FRACTAL_Z":
shape_b_temp = (shape_b_temp[0], shape_b_temp[1], shape_b_temp[2], shape_b_temp[3])
format_b = "fractal"
elif input_x2.get("format") == "FRACTAL_NZ":
shape_b_temp = (shape_b_temp[0], shape_b_temp[1], shape_b_temp[2], shape_b_temp[3])
format_b = "FRACTAL_NZ"
else:
shape_b_temp = (shape_b[len(shape_b) - 2], shape_b[len(shape_b) - 1])
format_b = "ND"
tensor_bias = None
tensor_a = tvm.placeholder(shape_a_temp, name='tensor_a',
dtype=src_dtype)
tensor_b = tvm.placeholder(shape_b_temp, name='tensor_b',
dtype=src_dtype)
if len(shape_bias) > 0:
tensor_bias = tvm.placeholder(shape_bias, name='tensor_bias',
dtype=dst_dtype)
result = te.lang.cce.matmul(tensor_a, tensor_b, trans_a, trans_b, format_a=format_a,
format_b=format_b, dst_dtype=dst_dtype, tensor_bias=tensor_bias)
with tvm.target.cce():
schedule = generic.auto_schedule(result)
tensor_list = [tensor_a, tensor_b, result]
if len(shape_bias) > 0:
tensor_list = [tensor_a, tensor_b, tensor_bias, result]
config = {"print_ir": False,
"name": kernel_name,
"tensor_list": tensor_list}
te.lang.cce.cce_build_code(schedule, config)
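# Hedged note on the format dispatch above: FORMAT_FRACTAL_Z and FRACTAL_NZ
# inputs keep their 4-D fractal shape and set format_a/format_b accordingly,
# while any other format falls back to plain 2-D "ND". The assumed helper below
# only documents that mapping and is not part of the build flow.
def _format_tag_example(tensor_format):
    """Illustrative only: mirrors the format_a/format_b selection above."""
    if tensor_format == "FORMAT_FRACTAL_Z":
        return "fractal"
    if tensor_format == "FRACTAL_NZ":
        return "FRACTAL_NZ"
    return "ND"
# _format_tag_example("FRACTAL_NZ") -> "FRACTAL_NZ"; _format_tag_example("NCHW") -> "ND"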

View File

@ -0,0 +1,81 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CusMatrixCombine"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
from te import tik
from topi.cce import util
cus_matrix_combine_op_info = TBERegOp("CusMatrixCombine") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("matrixcombine.so") \
.compute_cost(10) \
.kernel_name("CusMatrixCombine") \
.partial_flag(True) \
.input(0, "x1", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(cus_matrix_combine_op_info)
def CusMatrixCombine(input_x, output, kernel_name="matrix_combine"):
"""CusMatrixCombine"""
input_x_shape = input_x.get("shape")
output_shape = output.get("shape")
split_dim = 128
if util.get_product_version() == util.VERSION_MINI:
tik_instance = tik.Tik(tik.Dprofile("v100", "mini"))
else:
tik_instance = tik.Tik(tik.Dprofile("v100", "cloud"))
input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm)
res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm)
blocks = 32
matrix_dim = input_x_shape[0] * input_x_shape[1]
if input_x_shape[0] == 1 and input_x_shape[1] == 64:
tiling_dim = 2
bs = 1
with tik_instance.for_range(0, blocks, block_num=blocks) as block_index:
input_x_ub = tik_instance.Tensor("float32", (tiling_dim, matrix_dim), name="input_x_ub",
scope=tik.scope_ubuf)
tik_instance.data_move(input_x_ub, input_x[0, block_index * tiling_dim, 0], 0, 1, 16, 0, 0)
tik_instance.data_move(res[block_index * tiling_dim, 0], input_x_ub, 0, 1, 16, 0, 0)
else:
tiling_dim = 4
bs = input_x_shape[0]
with tik_instance.for_range(0, blocks, block_num=blocks) as block_index:
input_x_ub = tik_instance.Tensor("float32", (tiling_dim, matrix_dim), name="input_x_ub",
scope=tik.scope_ubuf)
zero = tik_instance.Scalar("float32")
zero.set_as(0.0)
with tik_instance.for_range(0, bs) as i:
repeat_real = tiling_dim * matrix_dim // 64
if repeat_real <= 255:
tik_instance.vector_dup(64, input_x_ub, zero, repeat_real, 1, 8)
else:
repeat_1 = 255
repeat_2 = repeat_real - 255
tik_instance.vector_dup(64, input_x_ub, zero, repeat_1, 1, 8)
tik_instance.vector_dup(64, input_x_ub[255 * 64], zero, repeat_2, 1, 8)
with tik_instance.for_range(0, tiling_dim) as j:
tik_instance.data_move(input_x_ub[j, split_dim * i], input_x[i, block_index * tiling_dim + j, 0], 0,
1, 16, 0, 0)
tik_instance.data_move(res[i * split_dim + block_index * tiling_dim, 0], input_x_ub, 0, 1,
tiling_dim * matrix_dim * 4 // 32, 0, 0)
tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x], outputs=[res])
return tik_instance
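# Hedged sketch of what the kernel above assembles: the input is a batch of
# square blocks and the output is the block-diagonal matrix built from them,
# with everything off the diagonal zero-filled by vector_dup. The pure-Python
# restatement below is illustrative only (small nested lists, not Ascend code).
def _combine_blocks_example(blocks, split_dim):
    """Illustrative only: place each (split_dim x split_dim) block on the diagonal."""
    size = len(blocks) * split_dim
    out = [[0.0] * size for _ in range(size)]
    for idx, block in enumerate(blocks):
        for row in range(split_dim):
            for col in range(split_dim):
                out[idx * split_dim + row][idx * split_dim + col] = block[row][col]
    return out
# _combine_blocks_example([[[1.0]], [[2.0]]], 1) -> [[1.0, 0.0], [0.0, 2.0]]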

View File

@ -0,0 +1,289 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CusTranspose02314"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
from te import tik
from topi.cce import util
cus_transpose02314_op_info = TBERegOp("CusTranspose02314") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("transpose02314.so") \
.compute_cost(10) \
.kernel_name("CusTranspose02314") \
.partial_flag(True) \
.input(0, "x1", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_Default) \
.get_op_info()
@op_info_register(cus_transpose02314_op_info)
def CusTranspose02314(input_x, output, kernel_name="transpose02314"):
"""CusTranspose02314"""
input_x_shape = input_x.get("shape")
output_shape = output.get("shape")
perm = (0, 2, 3, 1, 4)
input_x_shape = tuple(input_x_shape)
support_shape = [(32, 128, 7, 7, 16),
(32, 32, 7, 7, 16),
(32, 32, 14, 14, 16),
(32, 64, 14, 14, 16),
(32, 16, 14, 14, 16),
(32, 16, 28, 28, 16),
(32, 32, 28, 28, 16),
(32, 8, 28, 28, 16),
(32, 8, 56, 56, 16),
(32, 16, 56, 56, 16),
(32, 4, 56, 56, 16),
(32, 4, 112, 112, 16)]
if input_x_shape not in support_shape:
raise RuntimeError("input_shape %s is not supported" % str(input_x_shape))
if util.get_product_version() == util.VERSION_MINI:
tik_instance = tik.Tik(tik.Dprofile("v100", "mini"))
else:
tik_instance = tik.Tik(tik.Dprofile("v100", "cloud"))
input_x = tik_instance.Tensor("float16", input_x_shape, name="input_x", scope=tik.scope_gm)
res = tik_instance.Tensor("float16", output_shape, name="res", scope=tik.scope_gm)
dtype = "float16"
if tuple(input_x_shape) == (32, 4, 112, 112, 16):
with tik_instance.for_range(0, 32, block_num=32) as block_idx:
with tik_instance.for_range(0, 14) as cc1_db:
with tik_instance.for_range(0, 2, thread_num=2) as db_idx:
input_1_local_UB = tik_instance.Tensor(dtype, [28672], name="input_1_local_UB",
scope=tik.scope_ubuf)
T_transpose_local_UB = tik_instance.Tensor(dtype, [28672], name="T_transpose_local_UB",
scope=tik.scope_ubuf)
zero = tik_instance.Scalar(dtype="float16", init_value=0)
tik_instance.data_move(input_1_local_UB,
input_x[block_idx * 802816 + cc1_db * 14336 + 7168 * db_idx], 0, 4, 448,
12096, 0)
with tik_instance.for_range(0, 448) as cc7:
with tik_instance.for_range(0, 4) as cc8:
tik_instance.vadds(16, T_transpose_local_UB[cc7 * 64 + cc8 * 16],
input_1_local_UB[7168 * cc8 + cc7 * 16], zero, 1, 1, 1, 0, 0)
tik_instance.data_move(res[block_idx * 802816 + cc1_db * 57344 + 28672 * db_idx],
T_transpose_local_UB, 0, 1, 1792, 0, 0)
elif tuple(input_x_shape) == (32, 4, 56, 56, 16):
with tik_instance.for_range(0, 32, block_num=32) as block_idx:
zero = tik_instance.Scalar(dtype="float16", init_value=0)
with tik_instance.for_range(0, 3) as cc1_db:
with tik_instance.for_range(0, 2, thread_num=2) as db_idx:
input_1_local_UB = tik_instance.Tensor(dtype, [28672], name="input_1_local_UB",
scope=tik.scope_ubuf)
T_transpose_local_UB = tik_instance.Tensor(dtype, [28672], name="T_transpose_local_UB",
scope=tik.scope_ubuf)
tik_instance.data_move(input_1_local_UB,
input_x[block_idx * 200704 + cc1_db * 14336 + 7168 * db_idx], 0, 4, 448,
2688, 0)
with tik_instance.for_range(0, 448) as cc7:
with tik_instance.for_range(0, 4) as cc8:
tik_instance.vadds(16, T_transpose_local_UB[cc7 * 64 + cc8 * 16],
input_1_local_UB[7168 * cc8 + cc7 * 16], zero, 1, 1, 1, 0, 0)
tik_instance.data_move(res[block_idx * 200704 + cc1_db * 57344 + 28672 * db_idx],
T_transpose_local_UB, 0, 1, 1792, 0, 0)
input_1_local_UB2 = tik_instance.Tensor(dtype, [28672], name="input_1_local_UB2", scope=tik.scope_ubuf)
T_transpose_local_UB2 = tik_instance.Tensor(dtype, [28672], name="T_transpose_local_UB2",
scope=tik.scope_ubuf)
tik_instance.data_move(input_1_local_UB2, input_x[block_idx * 200704 + 43008], 0, 4, 448, 2688, 0)
with tik_instance.for_range(0, 448) as cc72:
with tik_instance.for_range(0, 4) as cc82:
tik_instance.vadds(16, T_transpose_local_UB2[cc72 * 64 + cc82 * 16],
input_1_local_UB2[7168 * cc82 + cc72 * 16], zero, 1, 1, 1, 0, 0)
tik_instance.data_move(res[block_idx * 200704 + 172032], T_transpose_local_UB2, 0, 1, 1792, 0, 0)
elif tuple(input_x_shape) == (32, 16, 56, 56, 16):
with tik_instance.for_range(0, 32, block_num=32) as block_idx:
zero = tik_instance.Scalar(dtype="float16", init_value=0)
with tik_instance.for_range(0, 14) as cc1_db:
with tik_instance.for_range(0, 2, thread_num=2) as db_idx:
input_1_local_UB = tik_instance.Tensor(dtype, [28672], name="input_1_local_UB",
scope=tik.scope_ubuf)
T_transpose_local_UB = tik_instance.Tensor(dtype, [28672], name="T_transpose_local_UB",
scope=tik.scope_ubuf)
tik_instance.data_move(input_1_local_UB,
input_x[block_idx * 802816 + cc1_db * 3584 + 1792 * db_idx], 0, 16, 112,
3024, 0)
with tik_instance.for_range(0, 112) as cc7:
with tik_instance.for_range(0, 16) as cc8:
tik_instance.vadds(16, T_transpose_local_UB[cc7 * 256 + cc8 * 16],
input_1_local_UB[1792 * cc8 + cc7 * 16], zero, 1, 1, 1, 0, 0)
tik_instance.data_move(res[block_idx * 802816 + cc1_db * 57344 + 28672 * db_idx],
T_transpose_local_UB, 0, 1, 1792, 0, 0)
elif tuple(input_x_shape) == (32, 8, 56, 56, 16):
with tik_instance.for_range(0, 32, block_num=32) as block_idx:
zero = tik_instance.Scalar(dtype="float16", init_value=0)
with tik_instance.for_range(0, 7) as cc1_db:
with tik_instance.for_range(0, 2, thread_num=2) as db_idx:
input_1_local_UB = tik_instance.Tensor(dtype, [28672], name="input_1_local_UB",
scope=tik.scope_ubuf)
T_transpose_local_UB = tik_instance.Tensor(dtype, [28672], name="T_transpose_local_UB",
scope=tik.scope_ubuf)
tik_instance.data_move(input_1_local_UB,
input_x[block_idx * 401408 + cc1_db * 7168 + 3584 * db_idx], 0, 8, 224, 2912,
0)
with tik_instance.for_range(0, 224) as cc7:
with tik_instance.for_range(0, 16) as cc8:
tik_instance.vadds(16, T_transpose_local_UB[cc7 * 128 + cc8 * 16],
input_1_local_UB[3584 * cc8 + cc7 * 16], zero, 1, 1, 1, 0, 0)
tik_instance.data_move(res[block_idx * 401408 + cc1_db * 57344 + 28672 * db_idx],
T_transpose_local_UB, 0, 1, 1792, 0, 0)
elif tuple(input_x_shape) == (32, 8, 28, 28, 16):
with tik_instance.for_range(0, 32, block_num=32) as block_idx:
zero = tik_instance.Scalar(dtype="float16", init_value=0)
with tik_instance.for_range(0, 2) as cc1_db:
with tik_instance.for_range(0, 2, thread_num=2) as db_idx:
input_1_local_UB = tik_instance.Tensor(dtype, [25088], name="input_1_local_UB",
scope=tik.scope_ubuf)
T_transpose_local_UB = tik_instance.Tensor(dtype, [25088], name="T_transpose_local_UB",
scope=tik.scope_ubuf)
tik_instance.data_move(input_1_local_UB,
input_x[block_idx * 100352 + cc1_db * 6272 + 3136 * db_idx], 0, 8, 196, 588,
0)
with tik_instance.for_range(0, 196) as cc7:
with tik_instance.for_range(0, 8) as cc8:
tik_instance.vadds(16, T_transpose_local_UB[cc7 * 128 + cc8 * 16],
input_1_local_UB[3136 * cc8 + cc7 * 16], zero, 1, 1, 1, 0, 0)
tik_instance.data_move(res[block_idx * 100352 + cc1_db * 50176 + 25088 * db_idx],
T_transpose_local_UB, 0, 1, 1568, 0, 0)
elif tuple(input_x_shape) == (32, 32, 28, 28, 16):
with tik_instance.for_range(0, 32, block_num=32) as block_idx:
zero = tik_instance.Scalar(dtype="float16", init_value=0)
with tik_instance.for_range(0, 7) as cc1_db:
with tik_instance.for_range(0, 2, thread_num=2) as db_idx:
input_1_local_UB = tik_instance.Tensor(dtype, [28672], name="input_1_local_UB",
scope=tik.scope_ubuf)
T_transpose_local_UB = tik_instance.Tensor(dtype, [28672], name="T_transpose_local_UB",
scope=tik.scope_ubuf)
tik_instance.data_move(input_1_local_UB, input_x[block_idx * 401408 + cc1_db * 1792 + 896 * db_idx],
0, 32, 56, 728, 0)
with tik_instance.for_range(0, 56) as cc7:
with tik_instance.for_range(0, 32) as cc8:
tik_instance.vadds(16, T_transpose_local_UB[cc7 * 512 + cc8 * 16],
input_1_local_UB[896 * cc8 + cc7 * 16], zero, 1, 1, 1, 0, 0)
tik_instance.data_move(res[block_idx * 401408 + cc1_db * 57344 + 28672 * db_idx],
T_transpose_local_UB, 0, 1, 1792, 0, 0)
elif tuple(input_x_shape) == (32, 16, 28, 28, 16):
with tik_instance.for_range(0, 32, block_num=32) as block_idx:
zero = tik_instance.Scalar(dtype="float16", init_value=0)
with tik_instance.for_range(0, 3) as cc1_db:
with tik_instance.for_range(0, 2, thread_num=2) as db_idx:
input_1_local_UB = tik_instance.Tensor(dtype, [28672], name="input_1_local_UB",
scope=tik.scope_ubuf)
T_transpose_local_UB = tik_instance.Tensor(dtype, [28672], name="T_transpose_local_UB",
scope=tik.scope_ubuf)
tik_instance.data_move(input_1_local_UB,
input_x[block_idx * 200704 + cc1_db * 3584 + 1792 * db_idx], 0, 16, 112, 672,
0)
with tik_instance.for_range(0, 112) as cc7:
with tik_instance.for_range(0, 16) as cc8:
tik_instance.vadds(16, T_transpose_local_UB[cc7 * 256 + cc8 * 16],
input_1_local_UB[1792 * cc8 + cc7 * 16], zero, 1, 1, 1, 0, 0)
tik_instance.data_move(res[block_idx * 200704 + cc1_db * 57344 + 28672 * db_idx],
T_transpose_local_UB, 0, 1, 1792, 0, 0)
input_1_local_UB2 = tik_instance.Tensor(dtype, [28672], name="input_1_local_UB2", scope=tik.scope_ubuf)
T_transpose_local_UB2 = tik_instance.Tensor(dtype, [28672], name="T_transpose_local_UB2",
scope=tik.scope_ubuf)
tik_instance.data_move(input_1_local_UB2, input_x[block_idx * 200704 + 10752], 0, 16, 112, 672, 0)
with tik_instance.for_range(0, 112) as cc7:
with tik_instance.for_range(0, 16) as cc8:
tik_instance.vadds(16, T_transpose_local_UB2[cc7 * 256 + cc8 * 16],
input_1_local_UB2[1792 * cc8 + cc7 * 16], zero, 1, 1, 1, 0, 0)
tik_instance.data_move(res[block_idx * 200704 + 172032], T_transpose_local_UB2, 0, 1, 1792, 0, 0)
elif tuple(input_x_shape) == (32, 16, 14, 14, 16):
with tik_instance.for_range(0, 32, block_num=32) as block_idx:
zero = tik_instance.Scalar(dtype="float16", init_value=0)
with tik_instance.for_range(0, 2, thread_num=2) as db_idx:
input_1_local_UB = tik_instance.Tensor(dtype, [25088], name="input_1_local_UB", scope=tik.scope_ubuf)
T_transpose_local_UB = tik_instance.Tensor(dtype, [25088], name="T_transpose_local_UB",
scope=tik.scope_ubuf)
tik_instance.data_move(input_1_local_UB, input_x[block_idx * 50176 + 1568 * db_idx], 0, 16, 98, 98, 0)
with tik_instance.for_range(0, 98) as cc7:
with tik_instance.for_range(0, 16) as cc8:
tik_instance.vadds(16, T_transpose_local_UB[cc7 * 256 + cc8 * 16],
input_1_local_UB[1568 * cc8 + cc7 * 16], zero, 1, 1, 1, 0, 0)
tik_instance.data_move(res[block_idx * 50176 + 25088 * db_idx], T_transpose_local_UB, 0, 1, 1568, 0, 0)
elif tuple(input_x_shape) == (32, 128, 7, 7, 16) and tuple(perm) == (0, 2, 3, 1, 4) and dtype == "float16":
with tik_instance.for_range(0, 32, block_num=32) as block_idx:
with tik_instance.for_range(0, 7, thread_num=2) as cc1:
input_x_ub = tik_instance.Tensor(dtype, [1, 128, 1, 7, 16], name="input_1_local_UB",
scope=tik.scope_ubuf)
transpose_ub = tik_instance.Tensor(dtype, [1, 1, 7, 128, 16], name="transpose_local_UB",
scope=tik.scope_ubuf)
tik_instance.data_move(input_x_ub, input_x[block_idx, 0, cc1, 0, 0], 0, 128, 7, 42, 0)
with tik_instance.for_range(0, 7) as cc7:
with tik_instance.for_range(0, 128) as cc8:
tik_instance.vadds(16, transpose_ub[0, 0, cc7, cc8, 0], input_x_ub[0, cc8, 0, cc7, 0], 0,
1, 1, 1, 0, 0)
tik_instance.data_move(res[block_idx * 100352 + 14336 * cc1], transpose_ub, 0, 1, 896, 0, 0)
elif tuple(input_x_shape) == (32, 32, 7, 7, 16) and tuple(perm) == (0, 2, 3, 1, 4) and dtype == "float16":
with tik_instance.for_range(0, 32, block_num=32) as block_idx:
input_x_ub = tik_instance.Tensor(dtype, [1, 32, 7, 7, 16], name="input_1_local_UB",
scope=tik.scope_ubuf)
transpose_ub = tik_instance.Tensor(dtype, [1, 7, 7, 32, 16], name="transpose_local_UB",
scope=tik.scope_ubuf)
tik_instance.data_move(input_x_ub, input_x[block_idx, 0, 0, 0, 0], 0, 1, 1568, 0, 0)
with tik_instance.for_range(0, 7) as cc1:
with tik_instance.for_range(0, 7) as cc2:
with tik_instance.for_range(0, 32) as cc3:
tik_instance.vadds(16, transpose_ub[0, cc1, cc2, cc3, 0], input_x_ub[0, cc3, cc1, cc2, 0], 0,
1, 1, 1, 0, 0)
tik_instance.data_move(res[block_idx * 25088], transpose_ub, 0, 1, 1568, 0, 0)
elif tuple(input_x_shape) == (32, 32, 14, 14, 16) and tuple(perm) == (0, 2, 3, 1, 4) and dtype == "float16":
def _inner_compute(split_index):
input_x_ub = tik_instance.Tensor(dtype, [1, 32, 2, 14, 16], name="input_1_local_UB",
scope=tik.scope_ubuf)
transpose_ub = tik_instance.Tensor(dtype, [1, 2, 14, 32, 16], name="transpose_local_UB",
scope=tik.scope_ubuf)
tik_instance.data_move(input_x_ub, input_x[block_idx, 0, split_index * 2, 0, 0], 0, 32, 28, 168, 0)
with tik_instance.for_range(0, 2) as cc2:
with tik_instance.for_range(0, 14) as cc3:
with tik_instance.for_range(0, 32) as cc4:
tik_instance.vadds(16, transpose_ub[0, cc2, cc3, cc4, 0], input_x_ub[0, cc4, cc2, cc3, 0],
0, 1, 1, 1, 0, 0)
tik_instance.data_move(res[block_idx * 100352 + split_index * 2 * 7168], transpose_ub, 0, 1, 896, 0, 0)
with tik_instance.for_range(0, 32, block_num=32) as block_idx:
with tik_instance.for_range(0, 6, thread_num=2) as cc1:
_inner_compute(cc1)
_inner_compute(6)
elif tuple(input_x_shape) == (32, 64, 14, 14, 16) and tuple(perm) == (0, 2, 3, 1, 4) and dtype == "float16":
def _inner_compute(split_index, block_idx):
input_x_ub = tik_instance.Tensor(dtype, [1, 64, 2, 14, 16], name="input_1_local_UB",
scope=tik.scope_ubuf)
transpose_ub = tik_instance.Tensor(dtype, [1, 2, 14, 64, 16], name="transpose_local_UB",
scope=tik.scope_ubuf)
tik_instance.data_move(input_x_ub, input_x[block_idx, 0, split_index * 2, 0, 0], 0, 64, 28, 168, 0)
with tik_instance.for_range(0, 2) as cc2:
with tik_instance.for_range(0, 14) as cc3:
with tik_instance.for_range(0, 64) as cc4:
tik_instance.vadds(16, transpose_ub[0, cc2, cc3, cc4, 0], input_x_ub[0, cc4, cc2, cc3, 0],
0, 1, 1, 1, 0, 0)
tik_instance.data_move(res[block_idx * 200704 + split_index * 2 * 14336], transpose_ub, 0, 1, 1792, 0, 0)
with tik_instance.for_range(0, 32, block_num=32) as block_idx:
with tik_instance.for_range(0, 6, thread_num=2) as cc1:
_inner_compute(cc1, block_idx)
_inner_compute(6, block_idx)
tik_instance.BuildCCE(kernel_name, inputs=[input_x], outputs=[res])
return tik_instance
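# Hedged sketch of the permutation above: perm = (0, 2, 3, 1, 4) turns an
# NC1HWC0 tensor (N, C1, H, W, C0) into (N, H, W, C1, C0), flattened into the
# 2-D output. The assumed helper below computes the flat destination offset of a
# single element and is for documentation only.
def _transpose02314_offset_example(n, c1, h, w, c0, shape):
    """Illustrative only: flat offset of element (n, c1, h, w, c0) after the permute."""
    _, dim_c1, dim_h, dim_w, dim_c0 = shape
    return (((n * dim_h + h) * dim_w + w) * dim_c1 + c1) * dim_c0 + c0
# with shape (32, 4, 112, 112, 16): _transpose02314_offset_example(0, 1, 0, 0, 0, shape) -> 16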

View File

@ -1,76 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""batch_matmul_impl"""
from mindspore.ops.op_info_register import op_info_register
@op_info_register("""{
"op_name": "CusBatchMatMul",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "batchmatmul.so",
"compute_cost": 10,
"kernel_name": "CusBatchMatMul",
"partial_flag": true,
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"float32"
],
"format": [
"DefaultFormat"
],
"name": "x1",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float32"
],
"format": [
"DefaultFormat"
],
"name": "x2",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float32"
],
"format": [
"DefaultFormat"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=True, kernel_name="batchmatmul"):
"""CusBatchMatMul"""
return

View File

@ -1,64 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CusCholeskyTrsm"""
from mindspore.ops.op_info_register import op_info_register
@op_info_register("""{
"op_name": "CusCholeskyTrsm",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "choleskytrsm.so",
"compute_cost": 10,
"kernel_name": "CusCholeskyTrsm",
"partial_flag": true,
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"float32"
],
"format": [
"DefaultFormat"
],
"name": "x1",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float32"
],
"format": [
"DefaultFormat"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
def CusCholeskyTrsm(input_x, output, kernel_name):
"""CusCholeskyTrsm"""
return

View File

@ -1,69 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CusFusedAbsMax1"""
from mindspore.ops.op_info_register import op_info_register
@op_info_register("""{
"op_name": "CusFusedAbsMax1",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "fusedabsmax1.so",
"compute_cost": 10,
"kernel_name": "CusFusedAbsMax1",
"partial_flag": true,
"attr": [
{
"name": "origin_shape",
"param_type": "required",
"type": "listInt",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float32"
],
"format": [
"DefaultFormat"
],
"name": "x1",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float32"
],
"format": [
"DefaultFormat"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_max1"):
"""CusFusedAbsMax1"""
return

View File

@ -1,87 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CusImg2ColNC1HWC0"""
from mindspore.ops.op_info_register import op_info_register
@op_info_register("""{
"op_name": "CusImg2ColNC1HWC0",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "img2colnc1hwc0.so",
"compute_cost": 10,
"kernel_name": "CusImg2ColNC1HWC0",
"partial_flag": true,
"attr": [
{
"name": "ksizes",
"param_type": "required",
"type": "listInt",
"value": "all"
},
{
"name": "strides",
"param_type": "required",
"type": "listInt",
"value": "all"
},
{
"name": "dilates",
"param_type": "required",
"type": "listInt",
"value": "all"
},
{
"name": "padding",
"param_type": "required",
"type": "str",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"NC1HWC0"
],
"name": "x1",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"FRACTAL_NZ"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
def CusImg2ColNC1HWC0(input_x, output, ksizes, strides, dilates, padding, kernel_name="img2col"):
"""CusImg2ColNC1HWC0"""
return

View File

@ -1,101 +0,0 @@
# -*- coding:utf-8 -*-
"""
copyright 2020 Huawei Technologies Co., Ltd
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
matmul
"""
from __future__ import absolute_import
from mindspore.ops.op_info_register import op_info_register
from topi.cce import util
# General limitation of the size for input shape: 2**31
SHAPE_SIZE_LIMIT = 2147483648
NoneType = type(None)
@op_info_register("""{
"op_name": "CusMatMulCubeDenseLeft",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "matmulcubedenseleft.so",
"compute_cost": 10,
"kernel_name": "CusMatMulCubeDenseLeft",
"partial_flag": true,
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"DefaultFormat"
],
"name": "x1",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16"
],
"format": [
"FRACTAL_NZ"
],
"name": "x2",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float16"
],
"format": [
"DefaultFormat"
],
"name": "x3",
"need_compile": false,
"param_type": "optional",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"FRACTAL_NZ"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
@util.check_input_type(dict, dict, (dict, NoneType), dict, bool, bool, str)
def CusMatMulCubeDenseLeft(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False,
kernel_name="matmulcube"):
"""CusMatMulCubeDenseLeft"""
return

View File

@ -1,102 +0,0 @@
# -*- coding:utf-8 -*-
"""
copyright 2020 Huawei Technologies Co., Ltd
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
matmul
"""
from __future__ import absolute_import
from mindspore.ops.op_info_register import op_info_register
from topi.cce import util
# General limitation of the size for input shape: 2**31
SHAPE_SIZE_LIMIT = 2147483648
NoneType = type(None)
@op_info_register("""{
"op_name": "CusMatMulCubeFraczLeftCast",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "matmulcubefraczleftcast.so",
"compute_cost": 10,
"kernel_name": "CusMatMulCubeFraczLeftCast",
"partial_flag": true,
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"DefaultFormat"
],
"name": "x1",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float32"
],
"format": [
"FracZ"
],
"name": "x2",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float16"
],
"format": [
"DefaultFormat"
],
"name": "x3",
"need_compile": false,
"param_type": "optional",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"FracZ"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
# pylint: disable=locally-disabled,too-many-arguments, too-many-locals, too-many-statements
@util.check_input_type(dict, dict, (dict, NoneType), dict, bool, bool, str)
def CusMatMulCubeFraczLeftCast(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False,
kernel_name="CusMatMulCubeFraczLeftCast"):
"""CusMatMulCubeFraczLeftCast"""
return

View File

@ -1,113 +0,0 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
copyright 2020 Huawei Technologies Co., Ltd
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
matmul
"""
from __future__ import absolute_import
from mindspore.ops.op_info_register import op_info_register
# General limitation of the size for input shape: 2**31
SHAPE_SIZE_LIMIT = 2147483648
NoneType = type(None)
@op_info_register("""{
"op_name": "CusMatMulCubeFraczRightMul",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "matmulcubefraczrightmul.so",
"compute_cost": 10,
"kernel_name": "CusMatMulCubeFraczRightMul",
"partial_flag": true,
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"FracZ"
],
"name": "x1",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16"
],
"format": [
"DefaultFormat"
],
"name": "x2",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float32"
],
"format": [
"DefaultFormat"
],
"name": "x3",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 3,
"dtype": [
"float16"
],
"format": [
"DefaultFormat"
],
"name": "x4",
"need_compile": false,
"param_type": "optional",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float32"
],
"format": [
"FracZ"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
def CusMatMulCubeFraczRightMul(input_x1, input_x2, input_x3, bias=None, output_y={}, trans_a=False, trans_b=False,
kernel_name="matmulcube"):
"""CusMatMulCubeFraczRightMul"""
return

View File

@ -1,114 +0,0 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
copyright 2020 Huawei Technologies Co., Ltd
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
matmul
"""
from __future__ import absolute_import
from mindspore.ops.op_info_register import op_info_register
from topi.cce import util
# General limitation of the size for input shape: 2**31
SHAPE_SIZE_LIMIT = 2147483648
NoneType = type(None)
@op_info_register("""{
"op_name": "CusMatMulCube",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "matmulcube.so",
"compute_cost": 10,
"kernel_name": "CusMatMulCube",
"partial_flag": true,
"attr": [
{
"name": "transpose_a",
"param_type": "required",
"type": "bool",
"value": "all"
},
{
"name": "transpose_b",
"param_type": "required",
"type": "bool",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"FRACTAL_NZ"
],
"name": "x1",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16"
],
"format": [
"FRACTAL_NZ"
],
"name": "x2",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float16"
],
"format": [
"DefaultFormat"
],
"name": "x3",
"need_compile": false,
"param_type": "optional",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float32"
],
"format": [
"FRACTAL_NZ"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
# pylint: disable=locally-disabled,too-many-arguments, too-many-locals, too-many-statements
@util.check_input_type(dict, dict, (dict, NoneType), dict, bool, bool, str)
def CusMatMulCube(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, kernel_name="matmulcube"):
"""CusMatMulCube"""
return

View File

@ -1,63 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CusMatrixCombine"""
from mindspore.ops.op_info_register import op_info_register
@op_info_register("""{
"op_name": "CusMatrixCombine",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "matrixcombine.so",
"compute_cost": 10,
"kernel_name": "CusMatrixCombine",
"partial_flag": true,
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"float32"
],
"format": [
"DefaultFormat"
],
"name": "x1",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float32"
],
"format": [
"DefaultFormat"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
def CusMatrixCombine(input_x, output, kernel_name="matrix_combine"):
"""CusMatrixCombine"""
return

View File

@ -1,63 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CusTranspose02314"""
from mindspore.ops.op_info_register import op_info_register
@op_info_register("""{
"op_name": "CusTranspose02314",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "transpose02314.so",
"compute_cost": 10,
"kernel_name": "CusTranspose02314",
"partial_flag": true,
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"NC1HWC0"
],
"name": "x1",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"DefaultFormat"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
def CusTranspose02314(input_x, output, kernel_name="transpose02314"):
"""CusTranspose02314"""
return
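All of the registration files above follow the same schema: the JSON describes the op's inputs and outputs (dtype, format), the real compute lives in the prebuilt binary named by `binfile_name`, and the decorated Python function is only a stub. A minimal sketch of reading the key fields from such a JSON string; the literal below uses abbreviated, illustrative values rather than a verbatim copy of any one file:

import json

# Abbreviated op info in the same schema as the registration blocks above (illustrative only).
OP_INFO = """{
    "op_name": "CusTranspose02314",
    "imply_type": "TBE",
    "binfile_name": "transpose02314.so",
    "kernel_name": "CusTranspose02314",
    "inputs": [{"index": 0, "dtype": ["float16"], "format": ["NC1HWC0"], "name": "x1"}],
    "outputs": [{"index": 0, "dtype": ["float16"], "format": ["DefaultFormat"], "name": "y"}]
}"""

info = json.loads(OP_INFO)
print(info["op_name"], "->", info["binfile_name"])
for tensor in info["inputs"] + info["outputs"]:
    print(tensor["name"], tensor["dtype"], tensor["format"])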

View File

@ -23,6 +23,7 @@ from mindspore._checkparam import Validator as validator
# path of built-in op info register.
BUILT_IN_OPS_REGISTER_PATH = "mindspore/ops/_op_impl"
BUILT_IN_CUSTOM_OPS_REGISTER_PATH = "mindspore/ops/_op_impl/_custom_op"
def op_info_register(op_info):
@ -47,7 +48,10 @@ def op_info_register(op_info):
op_lib = Oplib()
file_path = os.path.realpath(inspect.getfile(func))
# keep the path of the custom ops implementation.
imply_path = "" if BUILT_IN_OPS_REGISTER_PATH in file_path else file_path
if BUILT_IN_CUSTOM_OPS_REGISTER_PATH in file_path:
imply_path = file_path
else:
imply_path = "" if BUILT_IN_OPS_REGISTER_PATH in file_path else file_path
if not op_lib.reg_op(op_info_real, imply_path):
raise ValueError('Invalid op info {}:\n{}\n'.format(file_path, op_info_real))
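A small standalone sketch of the path-selection logic in the hunk above; the constants mirror the ones defined earlier in this file, and the example paths are hypothetical:

BUILT_IN_OPS_REGISTER_PATH = "mindspore/ops/_op_impl"
BUILT_IN_CUSTOM_OPS_REGISTER_PATH = "mindspore/ops/_op_impl/_custom_op"

def select_imply_path(file_path):
    """Mirror of the imply_path choice made during op info registration."""
    if BUILT_IN_CUSTOM_OPS_REGISTER_PATH in file_path:
        # Built-in custom ops (such as the THOR kernels) keep their real file path.
        return file_path
    # Other built-in ops record an empty path; out-of-tree ops keep theirs.
    return "" if BUILT_IN_OPS_REGISTER_PATH in file_path else file_path

# Hypothetical paths, for illustration only:
print(select_imply_path("/pkg/mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py"))  # kept
print(select_imply_path("/pkg/mindspore/ops/_op_impl/tbe/some_op.py"))                   # ""
print(select_imply_path("/home/user/my_custom_op.py"))                                   # kept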

View File

@ -70,6 +70,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey, CheckBprop
from . import _quant_ops
from ._quant_ops import *
from .thor_ops import *
__all__ = [
'TensorAdd',

View File

@ -13,19 +13,51 @@
# limitations under the License.
# ============================================================================
"""thor_ops"""
import mindspore as ms
from mindspore.ops import prim_attr_register, PrimitiveWithInfer
from mindspore.ops.composite import multitype_ops as C
import mindspore as ms
__all__ = ["CusBatchMatMul",
"CusCholeskyTrsm",
"CusFusedAbsMax1",
"CusImg2Col",
"CusMatMulCubeDenseLeft",
"CusMatMulCubeFraczRightMul",
"CusMatMulCube",
"CusMatrixCombine",
"CusTranspose02314",
"CusMatMulCubeDenseRight",
"CusMatMulCubeFraczLeftCast",
]
class CusBatchMatMul(PrimitiveWithInfer):
"""CusMatMulCube definition"""
"""
Multiplies matrix `a` by matrix `b` in batch.
The rank of input tensors must be `3`.
Inputs:
- **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, D, D)`.
- **input_y** (Tensor) - The second tensor to be multiplied. The shape of the tensor is :math:`(N, D, D)`; when
`transpose_b` is True it is transposed on its last two dimensions before the multiplication.
Outputs:
Tensor, the shape of the output tensor is :math:`(N, D, D)`.
Examples:
>>> input_x = Tensor(np.ones(shape=[2, 128, 128]), mindspore.float32)
>>> input_y = Tensor(np.ones(shape=[2, 128, 128]), mindspore.float32)
>>> cus_batch_matmul = P.CusBatchMatMul()
>>> output = cus_batch_matmul(input_x, input_y)
"""
@prim_attr_register
def __init__(self):
"""init CusMatMulCube"""
"""init CusBatchMatMul"""
self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y'])
from mindspore.ops._op_impl._custom_op.batch_matmul_impl import CusBatchMatMul
def get_bprop(self):
def bprop(x1, x2, out, dout):
return (C.zeros_like(x1), C.zeros_like(x2))
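# Rough NumPy reference for the (N, D, D) x (N, D, D) -> (N, D, D) contract documented
# above. transpose_b=True mirrors the default of the deleted TBE stub signature; the
# kernel's exact semantics are an assumption here, this is only an illustrative sketch.
import numpy as np

def batch_matmul_ref(x1, x2, transpose_b=True):
    y = x2.transpose(0, 2, 1) if transpose_b else x2
    return np.matmul(x1, y)

ref = batch_matmul_ref(np.ones((2, 128, 128), np.float32), np.ones((2, 128, 128), np.float32))
print(ref.shape)  # (2, 128, 128)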
@ -40,13 +72,30 @@ class CusBatchMatMul(PrimitiveWithInfer):
class CusCholeskyTrsm(PrimitiveWithInfer):
"""CusCholeskyTrsm definition"""
"""
Computes the inverse of the transposed Cholesky factor: given A = L * LT, returns (LT)^-1 (so that LT * (LT)^-1 = I).
Only the 128 x 128 diagonal blocks of the input matrix are processed.
The rank of input tensors must be `2`.
Inputs:
- **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, N)`.
Outputs:
Tensor, the shape of the output tensor is :math:`(N // Split_dim, Split_dim, Split_dim)`.
Examples:
>>> input_x = Tensor(np.ones(shape=[256, 256]), mindspore.float32)
>>> cus_choleskytrsm = P.CusCholeskyTrsm()
>>> output = cus_choleskytrsm(input_x)
"""
@prim_attr_register
def __init__(self):
"""init CusCholeskyTrsm"""
self.init_prim_io_names(inputs=['x1'], outputs=['y'])
from mindspore.ops._op_impl._custom_op.cholesky_trsm_impl import CusCholeskyTrsm
def infer_shape(self, data1_shape):
ll = []
m, _ = data1_shape
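# Rough NumPy reference for the relations in the docstring above: A = L @ L.T and the op
# returns (L.T)^-1, computed only on the 128 x 128 diagonal blocks. The split size and the
# exact factor being inverted are assumptions based on that docstring, not on the kernel.
import numpy as np

def cholesky_trsm_ref(a, split_dim=128):
    blocks = []
    for i in range(0, a.shape[0], split_dim):
        block = a[i:i + split_dim, i:i + split_dim]
        l = np.linalg.cholesky(block)      # block = l @ l.T
        blocks.append(np.linalg.inv(l.T))  # (l.T)^-1
    return np.stack(blocks)                # (n // split_dim, split_dim, split_dim)

spd = np.eye(256, dtype=np.float32) * 2.0  # symmetric positive definite input
print(cholesky_trsm_ref(spd).shape)        # (2, 128, 128)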
@ -61,14 +110,28 @@ class CusCholeskyTrsm(PrimitiveWithInfer):
class CusFusedAbsMax1(PrimitiveWithInfer):
"""CusCholeskyTrsm definition"""
"""
Computes the absolute maximum of the input tensor.
The rank of input tensors must be `4` or `2`.
Inputs:
- **input_x** (Tensor) - The input tensor. The shape of the tensor is :math:`(N0, M0, N1, M1)`
or :math:`(32, 64)`.
Outputs:
Tensor, the shape of the output tensor is :math:`(32, 64)` or :math:`(1, )`.
Examples:
>>> input_x = Tensor(np.ones(shape=[1, 3]), mindspore.float32)
>>> cus_fused_abs_max1 = P.CusFusedAbsMax1()
>>> output = cus_fused_abs_max1(input_x)
"""
@prim_attr_register
def __init__(self, origin_shape=[-1, -1]):
"""init CusCholeskyTrsm"""
"""init CusFusedAbsMax1"""
self.init_prim_io_names(inputs=['x1'], outputs=['y'])
self.origin_shape = origin_shape
from mindspore.ops._op_impl._custom_op.fused_abs_max1_impl import CusFusedAbsMax1
def get_bprop(self):
def bprop(x, out, dout):
return (C.zeros_like(x),)
@ -88,7 +151,21 @@ class CusFusedAbsMax1(PrimitiveWithInfer):
class CusImg2Col(PrimitiveWithInfer):
"""CusImg2Col definition"""
"""
Applies img2col to the feature map; the result is reorganized in the NC1HWC0 format.
Args:
- **strides** (listInt) - The strides of the sliding window.
- **ksizes** (listInt) - The kernel sizes of the sliding window.
Inputs:
- **input_x** (Tensor) - The shape of the tensor is :math:`(N, C, H, W)`.
Outputs:
Tensor, the shape of the output tensor is :math:`(N * H_O * W_O, C1 * K_W * K_H * C0)`.
Examples:
>>> input_x = Tensor(np.ones(shape=[32, 3, 224, 224]), mindspore.float16)
>>> cusimg2col = P.CusImg2Col()
>>> output = cusimg2col(input_x)
"""
@prim_attr_register
def __init__(self, ksizes, strides, dilates=(1, 1, 1, 1), mode="NC1HWC0"):
@ -98,7 +175,7 @@ class CusImg2Col(PrimitiveWithInfer):
self.strides = strides
self.dilates = dilates
self.mode = mode
from mindspore.ops._op_impl._custom_op.img2col_impl import CusImg2Col
def get_bprop(self):
def bprop(x, out, dout):
return (C.zeros_like(x),)
@ -122,13 +199,30 @@ class CusImg2Col(PrimitiveWithInfer):
class CusMatMulCubeDenseLeft(PrimitiveWithInfer):
"""CusMatMulCube definition"""
"""
Multiplies matrix `a` by matrix `b`.
The rank of input_x1 must be `4` (the fractal layout of a dense matrix).
The rank of input_x2 must be `2`.
Inputs:
- **input_x1** (Tensor) - The first tensor to be multiplied.
The shape of the tensor is :math:`(N0, M0, N1, M1)`.
- **input_x2** (Tensor) - The second tensor to be multiplied. The shape of the tensor is :math:`(M, C)`.
Outputs:
Tensor, the shape of the output tensor is :math:`(N, C)`.
Examples:
>>> input_x = Tensor(np.ones(shape=[16, 16, 16, 16]), mindspore.float16)
>>> input_y = Tensor(np.ones(shape=[256, 256]), mindspore.float16)
>>> matmulcubedenseleft = P.CusMatMulCubeDenseLeft()
>>> output = matmulcubedenseleft(input_x, input_y)
"""
@prim_attr_register
def __init__(self):
"""init CusMatMulCube"""
"""init CusMatMulCubeDenseLeft"""
self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y'])
from mindspore.ops._op_impl._custom_op.matmul_cube_dense_left_impl import CusMatMulCubeDenseLeft
def get_bprop(self):
def bprop(x1, x2, out, dout):
return (C.zeros_like(x1), C.zeros_like(x2))
@ -143,13 +237,32 @@ class CusMatMulCubeDenseLeft(PrimitiveWithInfer):
class CusMatMulCubeFraczRightMul(PrimitiveWithInfer):
"""CusMatMulCubeFraczRightMul definition"""
"""
Multiplies matrix `a` by matrix `b` and multiplies the result by scalar `c`.
The rank of input_x1 tensors must be `2`.
The rank of input_x2 tensors must be `4`.
Inputs:
- **input_x1** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, C)`.
- **input_x2** (Tensor) - The second tensor to be multiplied.
The shape of the tensor is :math:`(C1, M1, C0, M0)`.
- **input_x3** (Tensor) - The third tensor to be multiplied. The shape of the tensor is :math:`(1, )`.
Outputs:
Tensor, the shape of the output tensor is :math:`(N, M)`.
Examples:
>>> input_x1 = Tensor(np.ones(shape=[256, 256]), mindspore.float16)
>>> input_x2 = Tensor(np.ones(shape=[16, 16, 16, 16]), mindspore.float16)
>>> input_x3 = Tensor(np.ones(shape=[1, ]), mindspore.float16)
>>> cusmatmulfraczrightmul = P.CusMatMulCubeFraczRightMul()
>>> output = cusmatmulfraczrightmul(input_x1, input_x2, input_x3)
"""
@prim_attr_register
def __init__(self):
"""init CusMatMulCubeFraczRightMul"""
self.init_prim_io_names(inputs=['x1', 'x2', 'x3'], outputs=['y'])
from mindspore.ops._op_impl._custom_op.matmul_cube_fracz_right_mul_impl import CusMatMulCubeFraczRightMul
def get_bprop(self):
def bprop(x1, x2, x3, out, dout):
return (C.zeros_like(x1), C.zeros_like(x2), C.zeros_like(x3))
@ -164,7 +277,30 @@ class CusMatMulCubeFraczRightMul(PrimitiveWithInfer):
class CusMatMulCube(PrimitiveWithInfer):
"""CusMatMulCube definition"""
"""
Multiplies matrix `a` by matrix `b`.
The rank of input tensors must be `2`.
Args:
transpose_a (bool): If True, `a` is transposed before multiplication. Default: False.
transpose_b (bool): If True, `b` is transposed before multiplication. Default: False.
Inputs:
- **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, C)`. If
`transpose_a` is True, its shape should be :math:`(N, C)` after transposing.
- **input_y** (Tensor) - The second tensor to be multiplied. The shape of the tensor is :math:`(C, M)`. If
`transpose_b` is True, its shape should be :math:`(C, M)` after transpose.
Outputs:
Tensor, the shape of the output tensor is :math:`(N, M)`.
Examples:
>>> input_x = Tensor(np.ones(shape=[256, 256]), mindspore.float16)
>>> input_y = Tensor(np.ones(shape=[256, 256]), mindspore.float16)
>>> cusmatmulcube = P.CusMatMulCube()
>>> output = cusmatmulcube(input_x, input_y)
"""
@prim_attr_register
def __init__(self, transpose_a=False, transpose_b=False):
@ -172,7 +308,7 @@ class CusMatMulCube(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y'])
self.transpose_a = transpose_a
self.transpose_b = transpose_b
from mindspore.ops._op_impl._custom_op.matmul_cube_impl import CusMatMulCube
def get_bprop(self):
def bprop(x1, x2, out, dout):
return (C.zeros_like(x1), C.zeros_like(x2))
@ -199,13 +335,27 @@ class CusMatMulCube(PrimitiveWithInfer):
class CusMatrixCombine(PrimitiveWithInfer):
"""CusMatMulCube definition"""
"""
Moves each matrix in the input batch onto the corresponding diagonal block of the result matrix.
The rank of input tensors must be `3`.
Inputs:
- **input_x** (Tensor) - The shape of the tensor is :math:`(N, D, D)`.
Outputs:
Tensor, the shape of the output tensor is :math:`(N * D, N * D)`.
Examples:
>>> input_x = Tensor(np.ones(shape=[2, 128, 128]), mindspore.float32)
>>> cusmatrixcombine = P.CusMatrixCombine()
>>> output = cusmatrixcombine(input_x)
"""
@prim_attr_register
def __init__(self):
"""init CusMatMulCube"""
"""init CusMatrixCombine"""
self.init_prim_io_names(inputs=['x'], outputs=['y'])
from mindspore.ops._op_impl._custom_op.matrix_combine_impl import CusMatrixCombine
def get_bprop(self):
def bprop(x, out, dout):
return (C.zeros_like(x),)
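# Rough NumPy reference for the block-diagonal assembly described above: each (D, D)
# matrix in the (N, D, D) batch lands on the diagonal of an (N*D, N*D) zero matrix.
# Illustrative sketch of the documented shape behaviour, not the kernel itself.
import numpy as np

def matrix_combine_ref(x):
    n, d, _ = x.shape
    out = np.zeros((n * d, n * d), dtype=x.dtype)
    for i in range(n):
        out[i * d:(i + 1) * d, i * d:(i + 1) * d] = x[i]
    return out

print(matrix_combine_ref(np.ones((2, 128, 128), np.float32)).shape)  # (256, 256)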
@ -223,13 +373,28 @@ class CusMatrixCombine(PrimitiveWithInfer):
class CusTranspose02314(PrimitiveWithInfer):
"""CusTranspose02314 definition"""
"""
Permutes the input tensor with the permutation (0, 2, 3, 1, 4).
The rank of input tensors must be `5` with format NC1HWC0.
Inputs:
- **input_x** (Tensor) - The shape of the tensor is :math:`(N, C1, H, W, C0)`.
Outputs:
Tensor, the shape of the output tensor is :math:`(N, H, W, C1, C0)`.
Examples:
>>> input_x = Tensor(np.ones(shape=[32, 1, 224, 224, 16]), mindspore.float16)
>>> custranspose02314 = P.CusTranspose02314()
>>> output = custranspose02314(input_x)
"""
@prim_attr_register
def __init__(self):
"""init CusTranspose02314"""
self.init_prim_io_names(inputs=['x1'], outputs=['y'])
from mindspore.ops._op_impl._custom_op.transpose02314_impl import CusTranspose02314
def get_bprop(self):
def bprop(x, out, dout):
return (C.zeros_like(x),)
@ -246,3 +411,83 @@ class CusTranspose02314(PrimitiveWithInfer):
def infer_dtype(self, data1_dtype):
return data1_dtype
class CusMatMulCubeDenseRight(PrimitiveWithInfer):
"""
Multiplies matrix `a` by matrix `b`.
The rank of input_x1 tensor must be `2`.
The rank of input_x2 tensor must be `4`.
Inputs:
- **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, C)`.
- **input_y** (Tensor) - The second tensor to be multiplied.
The shape of the tensor is :math:`(C1, M1, M0, C0)`.
Outputs:
Tensor, the shape of the output tensor is :math:`(N, M)`.
Examples:
>>> input_x = Tensor(np.ones(shape=[256, 256]), mindspore.float16)
>>> input_y = Tensor(np.ones(shape=[16, 16, 16, 16]), mindspore.float16)
>>> cusmatmulcubedenseright = P.CusMatMulCubeDenseRight()
>>> output = cusmatmulcubedenseright(input_x, input_y)
"""
@prim_attr_register
def __init__(self):
"""init CusMatMulCubeDenseRight"""
self.init_prim_io_names(inputs=['x1', 'x2', 'x3'], outputs=['y'])
from mindspore.ops._op_impl._custom_op.matmul_cube_dense_right_impl import CusMatMulCubeDenseRight
def get_bprop(self):
def bprop(x1, x2, x3, out, dout):
return (C.zeros_like(x1), C.zeros_like(x2), C.zeros_like(x3))
return bprop
def infer_shape(self, data1_shape, data2_shape, data3_shape):
return data1_shape
def infer_dtype(self, data1_dtype, data2_dtype, data3_dtype):
return ms.common.dtype.tensor_type(getattr(ms, "float32"))
class CusMatMulCubeFraczLeftCast(PrimitiveWithInfer):
"""
Multiplies matrix `a` by matrix `b`.
The rank of input_x1 tensor must be `4`.
The rank of input_x2 tensors must be `2`.
Inputs:
- **input_x1** (Tensor) - The first tensor to be multiplied.
The shape of the tensor is :math:`(C1, N1, N0, C0)`.
- **input_x2** (Tensor) - The second tensor to be multiplied. The shape of the tensor is :math:`(C, M)`.
Outputs:
Tensor, the shape of the output tensor is :math:`(N, M)`.
Examples:
>>> input_x = Tensor(np.ones(shape=[16, 16, 16, 16]), mindspore.float16)
>>> input_y = Tensor(np.ones(shape=[256, 256]), mindspore.float16)
>>> cusmatmulcubefraczleftcast = P.CusMatMulCubeFraczLeftCast()
>>> output = cusmatmulcubefraczleftcast(input_x, input_y)
"""
@prim_attr_register
def __init__(self):
"""init CusMatMulCubeFraczLeftCast"""
self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y'])
from mindspore.ops._op_impl._custom_op.matmul_cube_fracz_left_cast_impl import CusMatMulCubeFraczLeftCast
def get_bprop(self):
def bprop(x1, x2, out, dout):
return (C.zeros_like(x1), C.zeros_like(x2))
return bprop
def infer_shape(self, data1_shape, data2_shape):
return data2_shape
def infer_dtype(self, data1_dtype, data2_dtype):
return ms.common.dtype.tensor_type(getattr(ms, "float16"))
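With `from .thor_ops import *` added to the operations package above, these primitives are reachable through `mindspore.ops.operations` (the `P` alias used in the docstring examples). A minimal usage sketch, assuming an Ascend device and the compiled custom kernels are available; it simply mirrors the docstring examples:

import numpy as np
import mindspore
from mindspore import Tensor, context
from mindspore.ops import operations as P

# Sketch only: these custom TBE kernels target Ascend, as in the docstring examples above.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

cus_batch_matmul = P.CusBatchMatMul()
cus_matmul_cube = P.CusMatMulCube(transpose_a=False, transpose_b=False)

# Shapes follow the docstring examples above.
x1 = Tensor(np.ones(shape=[2, 128, 128]), mindspore.float32)
x2 = Tensor(np.ones(shape=[2, 128, 128]), mindspore.float32)
a = Tensor(np.ones(shape=[256, 256]), mindspore.float16)
b = Tensor(np.ones(shape=[256, 256]), mindspore.float16)

batch_out = cus_batch_matmul(x1, x2)  # expected shape (2, 128, 128)
cube_out = cus_matmul_cube(a, b)      # expected shape (256, 256)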