add ascend customized kernel builder module

This commit is contained in:
zhaizhiqiang 2022-12-19 14:35:30 +08:00
parent d7eceaead0
commit c74909320b
25 changed files with 791 additions and 0 deletions

View File

@@ -71,3 +71,6 @@
"mindspore/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_winograd_fp32.cc" "knownConditionTrueFalse"
"mindspore/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_winograd_fp32.cc" "shadowVariable"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_utils.cc" "knownConditionTrueFalse"
"mindspore/mindspore/lite/tools/kernel_builder/ascend/tbe_dsl/sample/op_proto/add_dsl.cc" "syntaxError"
"mindspore/mindspore/lite/tools/kernel_builder/ascend/tbe_tik/sample/op_proto/matmul_tik.cc" "syntaxError"

View File

@@ -86,3 +86,23 @@
"mindspore/mindspore/lite/src/litert/delegate/nnapi/nnapi_implementation.cc" "build/include_order"
"mindspore/mindspore/lite/src/extendrt/cxx_api/model/model_impl.cc" "whitespace/parens"
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experimental/HPC-generator/gemm_mask_avx512/" "runtime/int"
# ascend samples
"mindspore/mindspore/lite/tools/kernel_builder/ascend/aicpu/sample/" "build/include_subdir"
"mindspore/mindspore/lite/tools/kernel_builder/ascend/tbe_dsl/sample/" "build/include_subdir"
"mindspore/mindspore/lite/tools/kernel_builder/ascend/tbe_tik/sample/" "build/include_subdir"
"mindspore/mindspore/lite/tools/kernel_builder/ascend/aicpu/sample/" "runtime/references"
"mindspore/mindspore/lite/tools/kernel_builder/ascend/tbe_dsl/sample/" "runtime/references"
"mindspore/mindspore/lite/tools/kernel_builder/ascend/tbe_tik/sample/" "runtime/references"
"mindspore/mindspore/lite/tools/kernel_builder/ascend/aicpu/sample/" "whitespace/comments"
"mindspore/mindspore/lite/tools/kernel_builder/ascend/tbe_dsl/sample/" "whitespace/comments"
"mindspore/mindspore/lite/tools/kernel_builder/ascend/tbe_tik/sample/" "whitespace/comments"
"mindspore/mindspore/lite/tools/kernel_builder/ascend/aicpu/sample/" "legal/copyright"
"mindspore/mindspore/lite/tools/kernel_builder/ascend/tbe_dsl/sample/" "legal/copyright"
"mindspore/mindspore/lite/tools/kernel_builder/ascend/tbe_tik/sample/" "legal/copyright"
"mindspore/mindspore/lite/tools/kernel_builder/ascend/aicpu/sample/" "whitespace/ending_newline"
"mindspore/mindspore/lite/tools/kernel_builder/ascend/tbe_dsl/sample/" "whitespace/ending_newline"
"mindspore/mindspore/lite/tools/kernel_builder/ascend/tbe_tik/sample/" "whitespace/ending_newline"
"mindspore/mindspore/lite/tools/kernel_builder/ascend/aicpu/sample/" "build/include"
"mindspore/mindspore/lite/tools/kernel_builder/ascend/tbe_dsl/sample/" "build/include"
"mindspore/mindspore/lite/tools/kernel_builder/ascend/tbe_tik/sample/" "build/include"

View File

@@ -184,3 +184,10 @@
"mindspore/mindspore/lite/python/api/tensor.py" "protected-access"
"mindspore/mindspore/lite/test" "missing-docstring"
"mindspore/mindspore/lite/test" "unused-variable"
# ascend samples
"mindspore/mindspore/lite/tools/kernel_builder/ascend/tbe_dsl/sample/" "wrong-import-order"
"mindspore/mindspore/lite/tools/kernel_builder/ascend/tbe_tik/sample/" "wrong-import-order"
"mindspore/mindspore/lite/tools/kernel_builder/ascend/tbe_dsl/sample/" "bad-whitespace"
"mindspore/mindspore/lite/tools/kernel_builder/ascend/tbe_tik/sample/" "bad-whitespace"
"mindspore/mindspore/lite/tools/kernel_builder/ascend/tbe_dsl/sample/" "bad-continuation"
"mindspore/mindspore/lite/tools/kernel_builder/ascend/tbe_tik/sample/" "bad-continuation"

View File

@@ -0,0 +1,2 @@
cmake_minimum_required(VERSION 3.12)
project(MS_ASCEND_CUSTOM_KERNEL_INSTALLER)

View File

@@ -0,0 +1,15 @@
Build Ascend customized kernels.
For more details, please refer to https://gitee.com/ascend/samples.git.
## Build
mkdir build
cd build
cmake ../
make
## Install
./ms_ascend_custom_kernel_installer.run
After installation, you can use the converter tool to convert models with customized kernels in an Ascend development environment.
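For example, a conversion command generally looks like the following (the paths are placeholders, and the exact converter options depend on your MindSpore Lite version and model format):
./converter_lite --fmk=ONNX --modelFile=model.onnx --outputFile=model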

View File

@@ -0,0 +1,9 @@
[ReshapeCust]
opInfo.engine=DNN_VM_AICPU
opInfo.flagPartial=False
opInfo.computeCost=100
opInfo.flagAsync=False
opInfo.opKernelLib=CUSTAICPUKernel
opInfo.kernelSo=libcust_aicpu_kernels.so
opInfo.functionName=RunCpuKernel
opInfo.workspaceSize=1024

View File

@@ -0,0 +1,41 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2020. All rights reserved.
 * Description: implementation of the ReshapeCust sample kernel
*/
#include "reshape_cust_kernels.h"
#include <string.h>
#include "cpu_types.h"
namespace {
const char *RESHAPE_CUST = "ReshapeCust";
}
namespace aicpu {
uint32_t ReshapeCustCpuKernel::Compute(CpuKernelContext &ctx) {
  // Validate the input/output tensors; a non-zero return signals failure to the AICPU framework.
  Tensor *input_tensor = ctx.Input(0);
  if (input_tensor == nullptr) {
    return -1;
  }
Tensor *output_tensor = ctx.Output(0);
if (output_tensor == nullptr) {
return -1;
}
auto input_data = input_tensor->GetData();
if (input_data == nullptr) {
return -1;
}
auto output_data = output_tensor->GetData();
if (output_data == nullptr) {
return -1;
}
  uint64_t data_size = input_tensor->GetDataSize();
  // Reshape does not change the underlying data, so copy it through unchanged.
  memcpy(output_data, input_data, data_size);
return 0;
}
REGISTER_CPU_KERNEL(RESHAPE_CUST, ReshapeCustCpuKernel);
} // namespace aicpu

View File

@@ -0,0 +1,28 @@
/* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef _AICPU_RESHAPE_CUST_KERNELS_H_
#define _AICPU_RESHAPE_CUST_KERNELS_H_
#include "cpu_kernel.h"
namespace aicpu {
class ReshapeCustCpuKernel : public CpuKernel {
public:
~ReshapeCustCpuKernel() = default;
uint32_t Compute(CpuKernelContext &ctx) override;
};
} // namespace aicpu
#endif  // _AICPU_RESHAPE_CUST_KERNELS_H_

View File

@@ -0,0 +1,2 @@
#!/bin/bash
# Install MindSpore Ascend customized kernel.

View File

@@ -0,0 +1,19 @@
# Copyright (c) Huawei Technologies Co., Ltd. 2020. All rights reserved.
set(CMAKE_CXX_COMPILER g++)
set(CMAKE_C_COMPILER gcc)
# add source files
aux_source_directory(. SRCS)
if("x${SRCS}" STREQUAL "x")
add_custom_target(${OP_PROTO_TARGET}
COMMAND mkdir -p ${OP_PROTO_TARGET_OUT_DIR}
COMMAND echo "no source to make lib${OP_PROTO_TARGET}.so")
  return()
endif()
set(LIBRARY_OUTPUT_PATH ${OP_PROTO_TARGET_OUT_DIR})
message(STATUS "OP_PROTO_TARGET=${OP_PROTO_TARGET}")
add_library(${OP_PROTO_TARGET} SHARED ${SRCS})
target_link_libraries(${OP_PROTO_TARGET} ${ASCEND_INC}/../lib64/libgraph.so)

View File

@@ -0,0 +1,94 @@
/**
* Copyright (C) 2019. Huawei Technologies Co., Ltd. All rights reserved.
* This program is free software; you can redistribute it and/or modify
 * it under the terms of the Apache License Version 2.0. You may not use this file except in compliance with the License.
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Apache License for more details at
* http://www.apache.org/licenses/LICENSE-2.0
*
 * @file add_dsl.cc
*
 * @brief shape inference and verification functions for the AddDsl operator
*
* @version 1.0
*
*/
#include "./add_dsl.h"
#include <string>
#include <vector>
namespace ge {
bool InferShapeAndTypeAdd(Operator &op, const string &input_name1, const string &input_name2,
const string &output_name) {
// vOutputDesc.push_back(op.GetInputDesc(0));
TensorDesc vOutputDesc = op.GetOutputDescByName(output_name.c_str());
DataType input_dtype = op.GetInputDescByName(input_name1.c_str()).GetDataType();
Format input_format = op.GetInputDescByName(input_name1.c_str()).GetFormat();
  // Swap the two shapes so that dimsX always holds the higher-rank one
ge::Shape shapeX = op.GetInputDescByName(input_name1.c_str()).GetShape();
ge::Shape shapeY = op.GetInputDescByName(input_name2.c_str()).GetShape();
std::vector<int64_t> dimsX = shapeX.GetDims();
std::vector<int64_t> dimsY = shapeY.GetDims();
if (dimsX.size() < dimsY.size()) {
std::vector<int64_t> dimsTmp = dimsX;
dimsX = dimsY;
dimsY = dimsTmp;
}
  // Pad the lower-rank shape with leading 1s
if (dimsX.size() != dimsY.size()) {
int dec = dimsX.size() - dimsY.size();
for (int i = 0; i < dec; i++) {
dimsY.insert(dimsY.begin(), (int64_t)1);
}
}
  // Compute each output dimension as the broadcast of the two inputs
std::vector<int64_t> dimVec;
for (size_t i = 0; i < dimsX.size(); i++) {
if ((dimsX[i] != dimsY[i]) && (dimsX[i] != 1) && (dimsY[i] != 1)) {
return false;
}
int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
dimVec.push_back(dims);
}
ge::Shape outputShape = ge::Shape(dimVec);
vOutputDesc.SetShape(outputShape);
vOutputDesc.SetDataType(input_dtype);
vOutputDesc.SetFormat(input_format);
op.UpdateOutputDesc(output_name.c_str(), vOutputDesc);
return true;
}
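// Worked example of the broadcast rule above (illustrative only):
// x1 shape (2, 3, 4) with x2 shape (3, 1) -> x2 is padded to (1, 3, 1),
// and the per-dimension maximum yields output shape (2, 3, 4).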
//----------------Add-------------------
IMPLEMT_VERIFIER(AddDsl, AddVerify) {
if (op.GetInputDescByName("x1").GetDataType() != op.GetInputDescByName("x2").GetDataType()) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
// Obtains the processing function of the output tensor description.
IMPLEMT_COMMON_INFERFUNC(AddInferShape) {
if (InferShapeAndTypeAdd(op, "x1", "x2", "y")) {
return GRAPH_SUCCESS;
}
return GRAPH_FAILED;
}
// Registered inferfunction
COMMON_INFER_FUNC_REG(AddDsl, AddInferShape);
// Registered verify function
VERIFY_FUNC_REG(AddDsl, AddVerify);
//----------------Add-------------------
} // namespace ge

View File

@@ -0,0 +1,35 @@
/**
* Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
* This program is free software; you can redistribute it and/or modify
 * it under the terms of the Apache License Version 2.0. You may not use this file except in compliance with the License.
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Apache License for more details at
* http://www.apache.org/licenses/LICENSE-2.0
*
* @file add_dsl.h
*
 * @brief operator prototype registration for AddDsl
*
* @version 1.0
*
*/
#ifndef GE_OPS_OP_PROTO_ADDDSL_H_
#define GE_OPS_OP_PROTO_ADDDSL_H_
#include "graph/operator_reg.h"
namespace ge {
REG_OP(AddDsl)
.INPUT(x1, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, DT_INT8, DT_UINT8, DT_DOUBLE,
DT_COMPLEX128, DT_COMPLEX64, DT_STRING}))
.INPUT(x2, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, DT_INT8, DT_UINT8, DT_DOUBLE,
DT_COMPLEX128, DT_COMPLEX64, DT_STRING}))
.OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, DT_INT8, DT_UINT8, DT_DOUBLE,
DT_COMPLEX128, DT_COMPLEX64, DT_STRING}))
.OP_END_FACTORY_REG(AddDsl)
}  // namespace ge
#endif // GE_OPS_OP_PROTO_ADDDSL_H_

View File

@@ -0,0 +1,117 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
Copyright (C) 2019. Huawei Technologies Co., Ltd. All rights reserved.
This program is free software; you can redistribute it and/or modify
it under the terms of the Apache License Version 2.0. You may not use this file
except in compliance with the License.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
Apache License for more details at
http://www.apache.org/licenses/LICENSE-2.0
add
"""
from __future__ import absolute_import
from functools import reduce
import tbe.dsl as tbe
from tbe import tvm
from tbe.common.register import register_op_compute
from tbe.common.utils import para_check
from tbe.common.utils import shape_util
# General limitation of the reduce size for input shape: 2**31
SHAPE_SIZE_LIMIT = 2147483648
# pylint: disable=locally-disabled,too-many-arguments,unused-argument
@register_op_compute("Add", op_mode="dynamic", support_fusion=True)
def add_compute(input_x, input_y, output_z, kernel_name="add"):
"""
calculating data's add, c = a + b
Parameters
----------
input_x: TVM tensor
the placeholder of first input data
input_y: TVM tensor
the placeholder of second input data
    output_z: dict
shape and dtype of output, should be broadcast shape and type as input
kernel_name: str
cce kernel name, default value is add
Returns
-------
res : output of the data's add
"""
shape_x = shape_util.shape_to_list(input_x.shape)
shape_y = shape_util.shape_to_list(input_y.shape)
shape_x, shape_y, shape_max = shape_util.broadcast_shapes(shape_x, shape_y,
param_name_input1="input_x",
param_name_input2="input_y")
shape_size = reduce(lambda x, y: x * y, shape_max[:])
if shape_size > SHAPE_SIZE_LIMIT:
raise RuntimeError("the shape is too large to calculate")
input_x = tbe.broadcast(input_x, shape_max)
input_y = tbe.broadcast(input_y, shape_max)
res = tbe.vadd(input_x, input_y)
return res
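# Example of the broadcast behaviour above: input_x of shape (16, 1) and
# input_y of shape (1, 32) are both broadcast to shape_max (16, 32) before vadd.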
@para_check.check_op_params(para_check.REQUIRED_INPUT, para_check.REQUIRED_INPUT,
para_check.REQUIRED_OUTPUT, para_check.KERNEL_NAME)
def add_dsl(input_x, input_y, output_z, kernel_name="add_dsl"):
"""
algorithm: add
calculating data's add, c = a + b
Parameters
----------
input_x : dict
shape and dtype of first input, only support float16, float32, int32
input_y : dict
shape and dtype of second input, only support float16, float32, int32
output_z: dict
shape and dtype of output, should be broadcast shape and type as input
kernel_name : str
cce kernel name, default value is add
Returns
-------
None
"""
shape_x = input_x.get("shape")
shape_y = input_y.get("shape")
check_tuple = ("float16", "float32", "int32")
input_data_type = input_x.get("dtype").lower()
para_check.check_dtype(input_data_type, check_tuple, param_name="input_x")
shape_x, shape_y, shape_max = shape_util.broadcast_shapes(shape_x, shape_y,
param_name_input1="input_x",
param_name_input2="input_y")
if shape_x[-1] == 1 and shape_y[-1] == 1 and shape_max[-1] == 1:
shape_x = shape_x if len(shape_x) == 1 else shape_x[:-1]
shape_y = shape_y if len(shape_y) == 1 else shape_y[:-1]
shape_max = shape_max if len(shape_max) == 1 else shape_max[:-1]
data_x = tvm.placeholder(shape_x, name="data_1", dtype=input_data_type)
data_y = tvm.placeholder(shape_y, name="data_2", dtype=input_data_type)
res = add_compute(data_x, data_y, output_z, kernel_name)
with tvm.target.cce():
schedule = tbe.auto_schedule(res)
config = {"name": kernel_name,
"tensor_list": (data_x, data_y, res)}
tbe.build(schedule, config)
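

# A minimal invocation sketch (hypothetical shapes and dtype; compiling the
# kernel requires an installed Ascend CANN/TBE toolchain):
if __name__ == "__main__":
    add_dsl({"shape": (16, 32), "dtype": "float16"},
            {"shape": (16, 32), "dtype": "float16"},
            {"shape": (16, 32), "dtype": "float16"},
            kernel_name="add_dsl_demo")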

View File

@@ -0,0 +1,18 @@
[AddDsl]
input0.name=x1
input0.dtype=float16,float16,float16,float16,float,float,float,float,int32,int32,int32,int32
input0.shape=all
input0.paramType=required
input0.format=NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND
input1.name=x2
input1.dtype=float16,float16,float16,float16,float,float,float,float,int32,int32,int32,int32
input1.shape=all
input1.paramType=required
input1.format=NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND
output0.name=y
output0.dtype=float16,float16,float16,float16,float,float,float,float,int32,int32,int32,int32
output0.shape=all
output0.paramType=required
output0.format=NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND
opFile.value=add_dsl
opInterface.value=add_dsl

View File

@@ -0,0 +1,18 @@
[AddDsl]
input0.name=x1
input0.dtype=float16,float16,float16,float16,float,float,float,float,int32,int32,int32,int32
input0.shape=all
input0.paramType=required
input0.format=NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND
input1.name=x2
input1.dtype=float16,float16,float16,float16,float,float,float,float,int32,int32,int32,int32
input1.shape=all
input1.paramType=required
input1.format=NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND
output0.name=y
output0.dtype=float16,float16,float16,float16,float,float,float,float,int32,int32,int32,int32
output0.shape=all
output0.paramType=required
output0.format=NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND
opFile.value=add_dsl
opInterface.value=add_dsl

View File

@@ -0,0 +1,18 @@
[AddDsl]
input0.name=x1
input0.dtype=float16,float16,float16,float16,float,float,float,float,int32,int32,int32,int32
input0.shape=all
input0.paramType=required
input0.format=NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND
input1.name=x2
input1.dtype=float16,float16,float16,float16,float,float,float,float,int32,int32,int32,int32
input1.shape=all
input1.paramType=required
input1.format=NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND
output0.name=y
output0.dtype=float16,float16,float16,float16,float,float,float,float,int32,int32,int32,int32
output0.shape=all
output0.paramType=required
output0.format=NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND
opFile.value=add_dsl
opInterface.value=add_dsl

View File

@@ -0,0 +1,19 @@
# Copyright (c) Huawei Technologies Co., Ltd. 2020. All rights reserved.
set(CMAKE_CXX_COMPILER g++)
set(CMAKE_C_COMPILER gcc)
# add source files
aux_source_directory(. SRCS)
if("x${SRCS}" STREQUAL "x")
add_custom_target(${OP_PROTO_TARGET}
COMMAND mkdir -p ${OP_PROTO_TARGET_OUT_DIR}
COMMAND echo "no source to make lib${OP_PROTO_TARGET}.so")
  return()
endif()
set(LIBRARY_OUTPUT_PATH ${OP_PROTO_TARGET_OUT_DIR})
message(STATUS "OP_PROTO_TARGET=${OP_PROTO_TARGET}")
add_library(${OP_PROTO_TARGET} SHARED ${SRCS})
target_link_libraries(${OP_PROTO_TARGET} ${ASCEND_INC}/../lib64/libgraph.so)

View File

@@ -0,0 +1,42 @@
#include "matmul_tik.h"
#include <string>
#include <vector>
namespace ge {
IMPLEMT_VERIFIER(MatmulTik, MatmulTikVerify) {
  // Candidate input dtypes kept for reference; this sample accepts all inputs unconditionally.
  std::vector<DataType> support_list;
  support_list.reserve(5);
support_list.push_back(DT_FLOAT16);
support_list.push_back(DT_FLOAT);
support_list.push_back(DT_INT32);
support_list.push_back(DT_INT8);
support_list.push_back(DT_UINT8);
return GRAPH_SUCCESS;
}
// Obtains the processing function of the output tensor description.
IMPLEMT_COMMON_INFERFUNC(MatmulTikInferShape) {
TensorDesc tensordesc_output = op.GetOutputDescByName("y");
ge::TensorDesc inputTensorDescX = op.GetInputDescByName("x1");
ge::TensorDesc inputTensorDescY = op.GetInputDescByName("x2");
ge::Shape shapeX = inputTensorDescX.GetShape();
ge::Shape shapeY = inputTensorDescY.GetShape();
DataType dtype = inputTensorDescX.GetDataType();
std::vector<int64_t> dimVector;
dimVector.push_back(shapeX.GetDim(0));
dimVector.push_back(shapeY.GetDim(1));
ge::Shape outputShape(dimVector);
tensordesc_output.SetShape(outputShape);
tensordesc_output.SetDataType(dtype);
(void)op.UpdateOutputDesc("y", tensordesc_output);
return GRAPH_SUCCESS;
}
// Registered inferfunction
COMMON_INFER_FUNC_REG(MatmulTik, MatmulTikInferShape);
// Registered verify function
VERIFY_FUNC_REG(MatmulTik, MatmulTikVerify);
} // namespace ge

View File

@@ -0,0 +1,14 @@
#ifndef GE_OP_MATMULTIK_H
#define GE_OP_MATMULTIK_H
#include "graph/operator_reg.h"
namespace ge {
REG_OP(MatmulTik)
.INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
.INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
.OP_END_FACTORY_REG(MatmulTik)
}  // namespace ge
#endif // GE_OP_MATMULTIK_H

View File

@@ -0,0 +1,210 @@
"""
Copyright 2020 Huawei Technologies Co., Ltd. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
matmul_tik
"""
from tbe import tik
from tbe.common.platform import get_soc_spec
DTYPE_SIZE = {
'bool': 1,
'uint8': 1,
'int8': 1,
'uint16': 2,
'int16': 2,
'int24': 3,
'uint32': 4,
'int32': 4,
'float16': 2,
'float32': 4,
'int48': 6,
'int64': 8,
'uint64': 8,
    'float64': 8
}
def MK_TO_K1MK0(tik_instance, mk_input_tensor, k1mk0_tensor, dtype, k1, m, k0):
"""data move mk to k1mk0"""
src_ub = tik_instance.Tensor(dtype, (k1, m, k0), name='src_ub', scope=tik.scope_ubuf)
# data_move(m, k) ---> (k1, m, k0)
with tik_instance.for_range(0, k1) as i:
tik_instance.data_move(src_ub[i * m * k0:], mk_input_tensor[i * k0:], 0, m, k0 * DTYPE_SIZE[dtype] // 32,
(k1 - 1) * k0 * DTYPE_SIZE[dtype] // 32, 0)
tik_instance.data_move(k1mk0_tensor, src_ub, 0, 1, k1 * m * k0 * DTYPE_SIZE[dtype] // 32, 0, 0)
def KN_TO_K1NK0(tik_instance, kn_input_tensor, k1nk0_tensor, dtype, k1, n, k0):
"""data move kn to k1nk0"""
with tik_instance.for_range(0, k1) as index:
k1nk0_ub = tik_instance.Tensor(dtype, (n, k0), tik.scope_ubuf, "k1nk0_ub")
src_ub = tik_instance.Tensor(dtype, (k0, n), tik.scope_ubuf, "src_ub")
burst_len = k0 * n * DTYPE_SIZE[dtype] // 32
tik_instance.data_move(src_ub, kn_input_tensor[index * k0 * n], 0, 1, burst_len, 0, 0)
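        # Transpose each (k0, n) tile to (n, k0) using 16x16 block transposes;
        # this sample path assumes 16-element blocks (the float16 case).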
dst_list = [k1nk0_ub[16 * i] for i in range(16)]
src_list = [src_ub[n * i] for i in range(16)]
rep_times = n // k0
dst_rep_stride = k0
src_rep_stride = 1
tik_instance.vec_trans_scatter(False, False, dst_list, src_list, rep_times, dst_rep_stride, src_rep_stride)
tik_instance.data_move(k1nk0_tensor[index * k0 * n], k1nk0_ub, 0, 1, burst_len, 0, 0)
def N1MN0_TO_MN(tik_instance, mn_output_tensor, n1mn0_tensor, dtype, n1, m, n0):
"""data move mn to n1mn0"""
src_ub = tik_instance.Tensor(dtype, (m, n1 * n0), name='src_ub', scope=tik.scope_ubuf)
# data_move(n1, m, n0) ---> (m, n)
with tik_instance.for_range(0, n1) as i:
tik_instance.data_move(src_ub[i * n0:], n1mn0_tensor[i * m * n0:], 0, m,
n0 * DTYPE_SIZE[dtype] // 32, 0, (n1 - 1) * n0 * DTYPE_SIZE[dtype] // 32)
tik_instance.data_move(mn_output_tensor, src_ub, 0, 1, m * n1 * n0 * DTYPE_SIZE[dtype] // 32, 0, 0)
def matmul_tik_compute(params, kernel_name):
"""
matmul tik compute
@param params: matmul data
@param kernel_name: kernel name
@return: tik instance
"""
tik_instance = tik.Tik()
if not isinstance(params, dict):
params = params.__dict__
m_size, k_size, n_size = params['M'], params['K'], params['N']
data_type = params["data_type"]
m_tiling_size = int(params["m_tiling_size"])
n_tiling_size = int(params["n_tiling_size"])
k_tiling_size = int(params['k_tiling_size'])
m_cycle_times = params["m_cycle_times"]
n_cycle_times = params["n_cycle_times"]
k_cycle_times = params["k_cycle_times"]
# Determine the output type
if data_type == "float16":
if get_soc_spec("SOC_VERSION") in ["SD3403", "OPTG", "Hi3796CV300CS", "TsnsC"]:
C_loc_out_type = "float16"
else:
C_loc_out_type = "float32"
K0 = 16
else:
C_loc_out_type = "int32"
K0 = 32
block_size = 16
n_thread_num = params['n_thread_num']
m_thread_num = params['m_thread_num']
k_thread_num = params['k_thread_num']
mk_gm_input = tik_instance.Tensor(data_type, (m_size, k_size), name="mk_input_gm", scope=tik.scope_gm)
kn_gm_input = tik_instance.Tensor(data_type, (k_size, n_size), name="kn_input_gm", scope=tik.scope_gm)
k1mk0_workspace = tik_instance.Tensor(data_type, (k_size // K0, m_size, K0), name="k1mk0_workspace",
scope=tik.scope_gm, is_workspace=True)
k1nk0_workspace = tik_instance.Tensor(data_type, (k_size // K0, n_size, K0), name="k1nk0_workspace",
scope=tik.scope_gm, is_workspace=True)
mn_gm_output = tik_instance.Tensor(C_loc_out_type, (m_size, n_size), tik.scope_gm, name="mn_output_gm")
nmk0_workspace = tik_instance.Tensor(C_loc_out_type, (n_size // block_size, m_size, block_size),
name="nmk0_workspace", scope=tik.scope_gm, is_workspace=True)
MK_TO_K1MK0(tik_instance, mk_gm_input, k1mk0_workspace, data_type, k_size // K0, m_size, K0)
KN_TO_K1NK0(tik_instance, kn_gm_input, k1nk0_workspace, data_type, k_size // K0, n_size, K0)
# Tiling is realized through the for_range() loop.
    with tik_instance.for_range(0, 2, block_num=1) as core_id:
with tik_instance.for_range(0, n_cycle_times // 2, thread_num=n_thread_num) as n_idx:
with tik_instance.for_range(0, m_cycle_times, thread_num=m_thread_num) as m_idx:
dst_l0c = tik_instance.Tensor(C_loc_out_type, [n_tiling_size // 16, m_tiling_size, 16], name='dst_l0c',
scope=tik.scope_cbuf_out)
with tik_instance.for_range(0, k_cycle_times,
thread_num=k_thread_num) as k_idx:
# Calculation result data transfer.
inputa_l1 = tik_instance.Tensor(params['data_type'], [k_tiling_size // K0, m_tiling_size, K0],
name="A_tiling_l1", scope=tik.scope_cbuf)
tik_instance.data_move(inputa_l1,
k1mk0_workspace[k_idx * k_tiling_size // K0, m_idx * m_tiling_size, :],
0, k_tiling_size // K0, m_tiling_size, m_size - m_tiling_size, 0)
inputb_l1 = tik_instance.Tensor(params["data_type"], [k_tiling_size // K0, n_tiling_size, K0],
name="B_tiling_l1", scope=tik.scope_cbuf)
if n_size - n_tiling_size > 65535:
with tik_instance.for_range(0, k_tiling_size // K0) \
as dma_k_idx:
tik_instance.data_move(inputb_l1[dma_k_idx, :, :],
k1nk0_workspace[k_idx * k_tiling_size // K0 + dma_k_idx,
(core_id * n_cycle_times // 2 + n_idx) * n_tiling_size, :],
0, 1, n_tiling_size, 0, 0)
else:
tik_instance.data_move(inputb_l1, k1nk0_workspace[k_idx * k_tiling_size // K0,
(core_id * n_cycle_times // 2 + n_idx) * n_tiling_size, :],
0, k_tiling_size // K0, n_tiling_size, n_size - n_tiling_size, 0)
# Call matmul API to matrix multiplication calculation.
with tik_instance.if_scope(k_idx == 0):
tik_instance.matmul(dst_l0c, inputa_l1, inputb_l1, m_tiling_size, k_tiling_size, n_tiling_size,
init_l1out=True)
with tik_instance.else_scope():
tik_instance.matmul(dst_l0c, inputa_l1, inputb_l1, m_tiling_size, k_tiling_size, n_tiling_size,
init_l1out=False)
tik_instance.fixpipe(nmk0_workspace[n_tiling_size // 16 * (core_id * n_cycle_times // 2 + n_idx),
m_idx * m_tiling_size, :], dst_l0c, n_tiling_size // 16, m_tiling_size * 16 *
DTYPE_SIZE[C_loc_out_type]//32,
(m_size - m_tiling_size) * 16 * DTYPE_SIZE[C_loc_out_type] // 32, 0)
N1MN0_TO_MN(tik_instance, mn_gm_output, nmk0_workspace, C_loc_out_type, n_size // K0, m_size, K0)
tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[mk_gm_input, kn_gm_input], outputs=[mn_gm_output])
return tik_instance
def matmul_tik(input_x1, input_x2, output_y=None, kernel_name="simple_matmul"):
"""
matmul_tik main func
Parameters
----------
input_x1: input data 1
input_x2: input data 2
    output_y: output data
    kernel_name: cce kernel name
    """
shape_a = input_x1.get("ori_shape")
shape_b = input_x2.get("ori_shape")
m = shape_a[0]
k = shape_a[1]
n = shape_b[1]
data_type = input_x1.get("dtype").lower()
params = {
'M': m,
'K': k,
'N': n,
'data_type': data_type,
'm_tiling_size': 16,
'm_cycle_times': 1,
'm_thread_num': 1,
'n_tiling_size': 64,
'n_cycle_times': 16,
'n_thread_num': 1,
'k_tiling_size': 32,
'k_cycle_times': 2,
'k_thread_num': 2,
        'output_y': output_y
}
return matmul_tik_compute(params, kernel_name)
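

# A minimal invocation sketch (hypothetical inputs sized to match the fixed
# tiling above, i.e. M = 16, K = 64, N = 1024; running it requires an Ascend
# CANN/TIK environment):
if __name__ == "__main__":
    matmul_tik({"ori_shape": (16, 64), "dtype": "float16"},
               {"ori_shape": (64, 1024), "dtype": "float16"},
               kernel_name="simple_matmul_demo")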

View File

@@ -0,0 +1,20 @@
[MatmulTik]
input0.name=x1
input0.dtype=int8,uint8,float16
input0.shape=all
input0.needCompile=false
input0.paramType=required
input0.format=ND,ND,ND
input1.name=x2
input1.dtype=int8,int8,float16
input1.shape=all
input1.needCompile=false
input1.paramType=required
input1.format=ND,ND,ND
output0.name=y
output0.dtype=int32,int32,float
output0.shape=all
output0.paramType=required
output0.format=ND,ND,ND
opFile.value=matmul_tik
opInterface.value=matmul_tik

View File

@@ -0,0 +1,20 @@
[MatmulTik]
input0.name=x1
input0.dtype=int8,uint8,float16
input0.shape=all
input0.needCompile=false
input0.paramType=required
input0.format=ND,ND,ND
input1.name=x2
input1.dtype=int8,int8,float16
input1.shape=all
input1.needCompile=false
input1.paramType=required
input1.format=ND,ND,ND
output0.name=y
output0.dtype=int32,int32,float
output0.shape=all
output0.paramType=required
output0.format=ND,ND,ND
opFile.value=matmul_tik
opInterface.value=matmul_tik

View File

@@ -0,0 +1,20 @@
[MatmulTik]
input0.name=x1
input0.dtype=int8,uint8,float16
input0.shape=all
input0.needCompile=false
input0.paramType=required
input0.format=ND,ND,ND
input1.name=x2
input1.dtype=int8,int8,float16
input1.shape=all
input1.needCompile=false
input1.paramType=required
input1.format=ND,ND,ND
output0.name=y
output0.dtype=int32,int32,float
output0.shape=all
output0.paramType=required
output0.format=ND,ND,ND
opFile.value=matmul_tik
opInterface.value=matmul_tik