forked from OSSInnovation/mindspore
!6080 [MS][LITE][Develop]support nnacl internal kernels
Merge pull request !6080 from chenjianping/lite_dev2
This commit is contained in:
commit
9e104137ac
|
@ -8,8 +8,10 @@ file(GLOB_RECURSE C_SRC ${CMAKE_CURRENT_SOURCE_DIR}/*.cc)
|
|||
# Gather the kernel sources into KERNEL_SRC: the nnacl C kernels (top level,
# fp32, fp32_grad, int8, quantization) plus the internal fp32 kernel wrappers.
file(GLOB KERNEL_SRC
        ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/*.c
        ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/fp32/*.c
        ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/fp32_grad/*.c
        ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/int8/*.c
        ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/quantization/*.c
        ${CMAKE_CURRENT_SOURCE_DIR}/src/kernel/fp32/*.cc
        )
# opt_op_handler.c is picked up by the nnacl/*.c pattern above; remove it from the list.
list(REMOVE_ITEM KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/opt_op_handler.c)
|
||||
|
||||
|
|
|
@ -84,7 +84,7 @@ typedef struct LiteSession {
|
|||
/// \param[in] inputs Define the new inputs shape.
|
||||
///
|
||||
/// \return STATUS as an error code of resize inputs, STATUS is defined in errorcode.h.
|
||||
int Resize(const TensorPtrVector &inputs);
|
||||
int Resize(const TensorPtrVector &inputs, Int32VectorVector dims);
|
||||
} LiteSession;
|
||||
|
||||
#endif // MINDSPORE_LITE_INCLUDE_LITE_SESSION_H
|
||||
|
|
|
@ -27,5 +27,6 @@ using String = std::string;
|
|||
// Container aliases shared by the internal lite API.
using StringVector = std::vector<std::string>;
using ShapeVector = std::vector<int>;
using NodePtrVector = std::vector<struct Node *>;

// Int32VectorVector holds one Int32Vector per entry; LiteSession::Resize takes it as `dims`.
using Int32Vector = std::vector<int32_t>;
using Int32VectorVector = std::vector<Int32Vector>;
|
||||
#endif // MINDSPORE_LITE_INCLUDE_LITE_UTILS_H_
|
||||
|
|
|
@ -27,6 +27,183 @@ enum NodeType {
|
|||
NodeType_MAX = NodeType_CNode
|
||||
};
|
||||
|
||||
/// Identifiers for the kernel (operator) kinds a node's primitive can carry.
/// The enumerators have no explicit values, so they are numbered from 0 in
/// declaration order: inserting or reordering entries changes the numeric value
/// of every entry that follows. The trailing "// N" markers give the implicit
/// value of the entry on that line.
enum KernelType {
  Concat,  // 0
  SoftMax,
  Activation,
  Conv2D,
  FusedBatchNorm,
  BatchNorm,
  BiasAdd,
  Pooling,
  ROIPooling,
  DepthwiseConv2D,
  DeDepthwiseConv2D,  // 10
  Resize,
  DetectionPostProcess,
  FullConnection,
  Mean,
  DeConv2D,
  Scale,
  Reshape,
  Eltwise,
  NetOutput,
  Add,  // 20
  Sub,
  MatMul,
  StridedSlice,
  Power,
  Slice,
  Stack,
  Mul,
  RealDiv,
  Pad,
  Maximum,  // 30
  Minimum,
  PReLU,
  LeakyReLU,
  ArgMax,
  ArgMin,
  Exp,
  Crop,
  Range,
  Rsqrt,
  ExpandDims,  // 40
  Tile,
  Cast,
  Shape,
  Nchw2Nhwc,
  Nhwc2Nchw,
  QuantDTypeCast,
  Split,
  Permute,
  FakeQuantWithMinMaxVars,
  Equal,  // 50
  Less,
  Greater,
  NotEqual,
  LessEqual,
  GreaterEqual,
  Min,
  Floor,
  Abs,
  Neg,
  Cos,  // 60
  Sin,
  Sqrt,
  Square,
  Constant,
  Log,
  Tan,
  Atan,
  Asin,
  Clip,
  Transpose,  // 70
  Squeeze,
  Unsqueeze,
  Upsample,
  Dropout,
  Broadcast,
  BroadcastTo,
  Lrn,
  ZerosLike,
  TopK,
  SpaceToDepth,  // 80
  SpaceToBatch,
  SparseToDense,
  ReverseSequence,
  Rank,
  Gather,
  GatherNd,
  Fill,
  Elu,
  DepthToSpace,
  BatchToSpace,  // 90
  AddN,
  Ceil,
  EmbeddingLookup,
  EmbeddingLookupSparse,
  FloorDiv,
  FloorMod,
  L2Norm,
  LocalResponseNormalization,
  MatrixDiag,
  Reduce,  // 100
  Reverse,
  Round,
  Select,
  Scatter,
  ScatterND,
  ConstantOfShape,
  Unique,
  Unstack,
  LogicalAnd,
  LogicalOr,  // 110
  LogicalXor,
  LogicalNot,
  OnnxInt8Quantize,
  OnnxInt8Dequantize,
  FakeQuantWithMinMax,
  FakeQuantWithMinMaxPerChannel,
  BatchNormFold,
  MulFold,
  AddFold,
  SquaredDifference,  // 120
  Flatten,
  FlattenGrad,
  TupleGetItem,
  Div,
  Where,
  OneHot,
  Lstm,
  Conv2DGradFilter,
  Conv2DGradInput,
  PoolingGrad,  // 130
  BNGrad,
  BNGradInput,
  ApplyMomentum,
  BiasGrad,
  SoftmaxCrossEntropy,
  AddGrad,
  SubGrad,
  MulGrad,
  DivGrad,
  PowerGrad,  // 140
  ActivationGrad,
  PriorBox,
  SpaceToBatchND,
  Depend,
  Return,
  MakeTuple,
  ToFormat,
  Proposal,
  Custom,
  BlackBox,  // 150
  NegGrad,
  LogGrad,
  BatchToSpaceND,
};
|
||||
|
||||
/// Activation function selectors. Every enumerator carries an explicit value,
/// so the numbering is stable regardless of declaration order.
enum ActivationType {
  NO_ACTIVATION = 0,
  RELU = 1,
  SIGMOID = 2,
  RELU6 = 3,
  ELU = 4,
  LEAKY_RELU = 5,
  ABS = 6,
  RELU1 = 7,
  SOFTSIGN = 8,
  SOFTPLUS = 9,
  TANH = 10,
  SELU = 11,
  HSWISH = 12,
  HSIGMOID = 13,
  THRESHOLDRELU = 14,
  LINEAR = 15,
  UNKNOW = 16
};
|
||||
|
||||
typedef struct Node {
|
||||
String name_;
|
||||
NodeType node_type_;
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "internal/src/kernel/fp32/activation.h"
|
||||
#include "internal/include/errorcode.h"
|
||||
#include "internal/include/ms_tensor.h"
|
||||
#include "nnacl/fp32/activation.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
/// Apply an fp32 activation element-wise from in_tensors[0] into out_tensors[0].
///
/// node->primitive_ is read as an ActivationParameter; its type_ selects the nnacl
/// routine (RELU, SIGMOID, RELU6 or LEAKY_RELU, the last one using alpha_).
/// The allocator argument is not used here.
///
/// \return RET_OK on success, RET_PARAM_INVALID for any other activation type,
///         RET_ERROR when the nnacl routine does not return NNACL_OK.
int DoActivation(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node,
                 mindspore::lite::Allocator *allocator) {
  ActivationParameter *act_param = (ActivationParameter *)node->primitive_;
  size_t elem_count = in_tensors[0]->ElementsNum();
  float *src = (float *)in_tensors[0]->data_;
  float *dst = (float *)out_tensors[0]->data_;

  const bool is_relu = act_param->type_ == ActivationType::RELU;
  const bool is_sigmoid = act_param->type_ == ActivationType::SIGMOID;
  const bool is_relu6 = act_param->type_ == ActivationType::RELU6;
  const bool is_leaky_relu = act_param->type_ == ActivationType::LEAKY_RELU;
  if (!is_relu && !is_sigmoid && !is_relu6 && !is_leaky_relu) {
    MS_LOG(ERROR) << "Unsupport activation type " << act_param->type_;
    return RET_PARAM_INVALID;
  }

  int status = RET_OK;
  if (is_relu) {
    status = Fp32Relu(src, elem_count, dst);
  } else if (is_sigmoid) {
    status = Sigmoid(src, elem_count, dst);
  } else if (is_relu6) {
    status = Fp32Relu6(src, elem_count, dst);
  } else {
    // Only LEAKY_RELU remains; it is the one activation that needs the slope.
    float slope = act_param->alpha_;
    status = LRelu(src, elem_count, dst, slope);
  }

  if (status == NNACL_OK) {
    return RET_OK;
  }
  MS_LOG(ERROR) << "do activation(" << act_param->type_ << ") fail!ret: " << status;
  return RET_ERROR;
}
|
|
@ -0,0 +1,26 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_ACTIVATION_H_
|
||||
#define MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_ACTIVATION_H_
|
||||
|
||||
#include "internal/include/model.h"
|
||||
#include "src/runtime/allocator.h"
|
||||
|
||||
/// Run an fp32 activation node: reads in_tensors[0], writes out_tensors[0].
/// node->primitive_ must point at the node's ActivationParameter.
/// \return an error code from internal/include/errorcode.h (RET_OK on success).
int DoActivation(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node,
                 mindspore::lite::Allocator *allocator);
|
||||
|
||||
#endif // MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_ACTIVATION_H_
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "internal/src/kernel/fp32/arithmetic_self.h"
|
||||
#include "internal/include/errorcode.h"
|
||||
#include "internal/include/ms_tensor.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "nnacl/fp32/arithmetic_self.h"
|
||||
|
||||
/// Run a unary fp32 arithmetic node (Log or Neg) from in_tensors[0] into out_tensors[0].
///
/// The operation is chosen by node->primitive_->type_. The allocator argument is not used.
///
/// \return RET_OK on success, RET_PARAM_INVALID for any other kernel type,
///         RET_ERROR when the nnacl routine does not return NNACL_OK.
int DoArithmeticSelf(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node,
                     mindspore::lite::Allocator *allocator) {
  size_t elem_count = in_tensors[0]->ElementsNum();
  OpParameter *op_param = node->primitive_;

  const bool is_log = op_param->type_ == KernelType::Log;
  const bool is_neg = op_param->type_ == KernelType::Neg;
  if (!is_log && !is_neg) {
    MS_LOG(ERROR) << "Unsupport kernel type: " << op_param->type_;
    return RET_PARAM_INVALID;
  }

  float *src = (float *)in_tensors[0]->data_;
  float *dst = (float *)out_tensors[0]->data_;
  int status;
  if (is_log) {
    status = ElementLog(src, dst, elem_count);
  } else {
    status = ElementNegative(src, dst, elem_count);
  }

  if (status == NNACL_OK) {
    return RET_OK;
  }
  MS_LOG(ERROR) << "do arithmetic " << op_param->type_ << " fail!ret: " << status;
  return RET_ERROR;
}
|
|
@ -0,0 +1,26 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_ARITHMETIC_SELF_H_
|
||||
#define MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_ARITHMETIC_SELF_H_
|
||||
|
||||
#include "internal/include/model.h"
|
||||
#include "src/runtime/allocator.h"
|
||||
|
||||
/// Run a unary fp32 arithmetic node (Log or Neg): reads in_tensors[0], writes out_tensors[0].
/// The operation is selected by node->primitive_->type_.
/// \return an error code from internal/include/errorcode.h (RET_OK on success).
int DoArithmeticSelf(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node,
                     mindspore::lite::Allocator *allocator);
|
||||
|
||||
#endif // MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_ARITHMETIC_SELF_H_
|
|
@ -0,0 +1,145 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "internal/src/kernel/fp32/matmul.h"
|
||||
#include "nnacl/fp32/matmul.h"
|
||||
#include "internal/include/errorcode.h"
|
||||
#include "internal/include/ms_tensor.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
// Scratch buffers used by DoMatMul for a single run.
struct MatMulCPUKernelData {
  float *a_c12_ptr_;  // matrix A after packing by MatMulInitMatrixA
  float *b_r8_ptr_;   // matrix B after packing by MatMulInitMatrixB
  float *bias_ptr_;   // bias row handed to MatMulOpt
};
|
||||
|
||||
void MatMulInitMatrixA(float *src_ptr, float *dst_ptr, MatMulParameter *params) {
|
||||
for (int i = 0; i < params->batch; i++) {
|
||||
float *src = src_ptr + i * params->deep_ * params->row_;
|
||||
float *dst = dst_ptr + i * params->deep_ * params->row_12_;
|
||||
if (params->a_transpose_) {
|
||||
RowMajor2Row12Major(src, dst, params->deep_, params->row_);
|
||||
} else {
|
||||
RowMajor2Col12Major(src, dst, params->row_, params->deep_);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void MatMulInitMatrixB(float *src_ptr, float *dst_ptr, MatMulParameter *params) {
|
||||
for (int i = 0; i < params->batch; i++) {
|
||||
float *src = src_ptr + i * params->deep_ * params->col_;
|
||||
float *dst = dst_ptr + i * params->deep_ * params->col_8_;
|
||||
if (params->b_transpose_) {
|
||||
RowMajor2Col8Major(src, dst, params->col_, params->deep_);
|
||||
} else {
|
||||
RowMajor2Row8Major(src, dst, params->deep_, params->col_);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Release the scratch buffers held by kernel_data through the allocator, then
// free the struct itself. A NULL kernel_data is a no-op; NULL buffers are skipped.
void FreeMatMulKernelData(MatMulCPUKernelData *kernel_data, mindspore::lite::Allocator *allocator) {
  if (kernel_data == NULL) {
    return;
  }

  // Same release order as the fields are declared: A, B, bias.
  float **buffers[] = {&kernel_data->a_c12_ptr_, &kernel_data->b_r8_ptr_, &kernel_data->bias_ptr_};
  for (float **slot : buffers) {
    if (*slot == NULL) {
      continue;
    }
    allocator->Free(*slot);
    *slot = NULL;
  }

  free(kernel_data);
}
|
||||
|
||||
/// Run an fp32 MatMul node.
///
/// A (in_tensors[0]) and B (in_tensors[1]) are packed into scratch buffers, an
/// optional bias (in_tensors[2]) is copied into a zero-padded row, and MatMulOpt is
/// called once per batch to fill out_tensors[0]. All scratch memory is released
/// before returning, on success as well as on failure.
///
/// \param[in] in_tensors   A, B and optionally a bias tensor.
/// \param[out] out_tensors out_tensors[0] receives the product.
/// \param[in,out] node     node->primitive_ is used as a MatMulParameter; its batch,
///                         row_, col_, deep_, row_12_ and col_8_ fields are overwritten.
/// \param[in] allocator    Provides the packed A/B buffers and the bias row.
///
/// \return RET_OK on success; RET_PARAM_INVALID for missing tensors, data, node or
///         allocator; RET_INPUT_TENSOR_ERROR for unusable shapes; RET_MEMORY_FAILED
///         when an allocation fails.
int DoMatMul(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node,
             mindspore::lite::Allocator *allocator) {
  if (node == NULL || node->primitive_ == NULL) {
    MS_LOG(ERROR) << "node or node's primitive is NULL!";
    return RET_PARAM_INVALID;
  }
  if (in_tensors.size() < 2 || out_tensors.empty()) {
    MS_LOG(ERROR) << "matmul needs at least 2 inputs and 1 output!";
    return RET_PARAM_INVALID;
  }
  if (in_tensors[0]->data_ == NULL || in_tensors[1]->data_ == NULL) {
    MS_LOG(ERROR) << "input data is NULL!";
    return RET_PARAM_INVALID;
  }
  if (out_tensors[0]->data_ == NULL) {
    MS_LOG(ERROR) << "output data is NULL!";
    return RET_PARAM_INVALID;
  }
  if (allocator == NULL) {
    MS_LOG(ERROR) << "input allocator is NULL!";
    return RET_PARAM_INVALID;
  }

  const std::vector<int> &a_shape = in_tensors[0]->shape_;
  const std::vector<int> &c_shape = out_tensors[0]->shape_;
  // The index arithmetic below uses size() - 2; reject ranks that would underflow it.
  if (a_shape.size() < 2 || c_shape.size() < 2) {
    MS_LOG(ERROR) << "matmul input and output must have at least 2 dimensions!";
    return RET_INPUT_TENSOR_ERROR;
  }

  const bool has_bias = in_tensors.size() == 3;
  if (has_bias) {
    const std::vector<int> &bias_shape = in_tensors[2]->shape_;
    if (bias_shape.empty() || bias_shape[bias_shape.size() - 1] != c_shape[c_shape.size() - 1]) {
      MS_LOG(ERROR) << "The bias' dimension is not equal with column";
      return RET_INPUT_TENSOR_ERROR;
    }
    if (in_tensors[2]->data_ == NULL) {
      MS_LOG(ERROR) << "bias data is NULL!";
      return RET_PARAM_INVALID;
    }
  }

  // Every dimension of A before the last two contributes to the batch count.
  int batch = 1;
  for (size_t i = 0; i + 2 < a_shape.size(); ++i) {
    batch *= a_shape[i];
  }

  MatMulParameter *params = (MatMulParameter *)node->primitive_;
  params->batch = batch;
  params->row_ = c_shape[c_shape.size() - 2];
  params->col_ = c_shape[c_shape.size() - 1];
  params->deep_ = params->a_transpose_ ? a_shape[a_shape.size() - 2] : a_shape[a_shape.size() - 1];
  params->row_12_ = UP_ROUND(params->row_, C12NUM);
  params->col_8_ = UP_ROUND(params->col_, 8);

  // Sizes are computed in size_t so large shapes do not overflow int arithmetic.
  const size_t a_pack_bytes = (size_t)params->batch * params->row_12_ * params->deep_ * sizeof(float);
  const size_t b_pack_bytes = (size_t)params->batch * params->col_8_ * params->deep_ * sizeof(float);
  const size_t bias_bytes = (size_t)params->col_8_ * sizeof(float);

  MatMulCPUKernelData *kernel_data = (MatMulCPUKernelData *)malloc(sizeof(MatMulCPUKernelData));
  if (kernel_data == NULL) {
    MS_LOG(ERROR) << "malloc matmul kernel data fail!";
    return RET_MEMORY_FAILED;
  }
  // Start from NULL so FreeMatMulKernelData never sees an indeterminate pointer.
  kernel_data->a_c12_ptr_ = NULL;
  kernel_data->b_r8_ptr_ = NULL;
  kernel_data->bias_ptr_ = NULL;

  kernel_data->a_c12_ptr_ = (float *)(allocator->Malloc(a_pack_bytes));
  if (kernel_data->a_c12_ptr_ == NULL) {
    FreeMatMulKernelData(kernel_data, allocator);
    return RET_MEMORY_FAILED;
  }
  // Zero the whole buffer (all batches) so padding rows are defined.
  memset(kernel_data->a_c12_ptr_, 0, a_pack_bytes);

  kernel_data->b_r8_ptr_ = (float *)(allocator->Malloc(b_pack_bytes));
  if (kernel_data->b_r8_ptr_ == NULL) {
    FreeMatMulKernelData(kernel_data, allocator);
    return RET_MEMORY_FAILED;
  }
  memset(kernel_data->b_r8_ptr_, 0, b_pack_bytes);

  kernel_data->bias_ptr_ = (float *)(allocator->Malloc(bias_bytes));
  if (kernel_data->bias_ptr_ == NULL) {
    FreeMatMulKernelData(kernel_data, allocator);
    return RET_MEMORY_FAILED;
  }
  // Without a bias input the row stays all zeros.
  memset(kernel_data->bias_ptr_, 0, bias_bytes);
  if (has_bias) {
    memcpy(kernel_data->bias_ptr_, in_tensors[2]->data_, (size_t)params->col_ * sizeof(float));
  }

  MatMulInitMatrixA((float *)in_tensors[0]->data_, kernel_data->a_c12_ptr_, params);
  MatMulInitMatrixB((float *)in_tensors[1]->data_, kernel_data->b_r8_ptr_, params);

  float *c_src = (float *)out_tensors[0]->data_;
  for (int i = 0; i < params->batch; ++i) {
    float *a_ptr = kernel_data->a_c12_ptr_ + i * params->row_12_ * params->deep_;
    float *b_ptr = kernel_data->b_r8_ptr_ + i * params->deep_ * params->col_8_;
    float *c_ptr = c_src + i * params->row_ * params->col_;
    MatMulOpt(a_ptr, b_ptr, c_ptr, kernel_data->bias_ptr_, ActType_No, params->deep_, params->row_, params->col_,
              params->col_, OutType_Nhwc);
  }

  // The scratch buffers are local to this call; release them to avoid leaking per run.
  FreeMatMulKernelData(kernel_data, allocator);
  return RET_OK;
}
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_MATMUL_H_
|
||||
#define MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_MATMUL_H_
|
||||
|
||||
#include "internal/include/model.h"
|
||||
#include "src/runtime/allocator.h"
|
||||
|
||||
/// Run an fp32 MatMul node: in_tensors[0] x in_tensors[1] (+ optional bias in
/// in_tensors[2]) into out_tensors[0]. node->primitive_ must point at the node's
/// MatMulParameter; allocator supplies the temporary packing buffers.
/// \return an error code from internal/include/errorcode.h (RET_OK on success).
int DoMatMul(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node,
             mindspore::lite::Allocator *allocator);
|
||||
|
||||
#endif // MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_MATMUL_H_
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "internal/src/kernel/fp32_grad/activation_grad.h"
|
||||
#include "internal/include/errorcode.h"
|
||||
#include "internal/include/ms_tensor.h"
|
||||
#include "nnacl/fp32_grad/activation_grad.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
/// Run the fp32 gradient of an activation node.
///
/// in_tensors[0] and in_tensors[1] are handed to the nnacl gradient routine as its
/// first two data arguments (named dy and x here) and the result goes to
/// out_tensors[0]. node->primitive_ is read as an ActivationGradParameter; its type_
/// selects RELU, SIGMOID, RELU6 or LEAKY_RELU (the last one using alpha_).
/// The allocator argument is not used.
///
/// \return RET_OK on success, RET_PARAM_INVALID for any other activation type,
///         RET_ERROR when the nnacl routine does not return NNACL_OK.
int DoActivationGrad(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node,
                     mindspore::lite::Allocator *allocator) {
  ActivationGradParameter *grad_param = (ActivationGradParameter *)node->primitive_;
  size_t elem_count = in_tensors[0]->ElementsNum();
  float *dy = (float *)in_tensors[0]->data_;
  float *x = (float *)in_tensors[1]->data_;
  float *dx = (float *)out_tensors[0]->data_;

  const bool is_relu = grad_param->type_ == ActivationType::RELU;
  const bool is_sigmoid = grad_param->type_ == ActivationType::SIGMOID;
  const bool is_relu6 = grad_param->type_ == ActivationType::RELU6;
  const bool is_leaky_relu = grad_param->type_ == ActivationType::LEAKY_RELU;
  if (!is_relu && !is_sigmoid && !is_relu6 && !is_leaky_relu) {
    MS_LOG(ERROR) << "Unsupport activation type " << grad_param->type_;
    return RET_PARAM_INVALID;
  }

  int status = RET_OK;
  if (is_relu) {
    status = ReluGrad(dy, x, elem_count, dx);
  } else if (is_sigmoid) {
    status = SigmoidGrad(dy, x, elem_count, dx);
  } else if (is_relu6) {
    status = Relu6Grad(dy, x, elem_count, dx);
  } else {
    // Only LEAKY_RELU remains; it is the one gradient that needs the slope.
    float slope = grad_param->alpha_;
    status = LReluGrad(dy, x, elem_count, dx, slope);
  }

  if (status == NNACL_OK) {
    return RET_OK;
  }
  MS_LOG(ERROR) << "do activation(" << grad_param->type_ << ") fail!ret: " << status;
  return RET_ERROR;
}
|
|
@ -0,0 +1,26 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_GRAD_ACTIVATION_GRAD_H_
|
||||
#define MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_GRAD_ACTIVATION_GRAD_H_
|
||||
|
||||
#include "internal/include/model.h"
|
||||
#include "src/runtime/allocator.h"
|
||||
|
||||
/// Run the fp32 gradient of an activation node: reads in_tensors[0] and
/// in_tensors[1], writes out_tensors[0]. node->primitive_ must point at the node's
/// ActivationGradParameter.
/// \return an error code from internal/include/errorcode.h (RET_OK on success).
int DoActivationGrad(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node,
                     mindspore::lite::Allocator *allocator);
|
||||
|
||||
#endif // MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_GRAD_ACTIVATION_GRAD_H_
|
|
@ -0,0 +1,45 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "internal/src/kernel/fp32_grad/arithmetic_self_grad.h"
|
||||
#include "internal/include/errorcode.h"
|
||||
#include "internal/include/ms_tensor.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "nnacl/fp32/arithmetic_self.h"
|
||||
#include "nnacl/fp32/arithmetic.h"
|
||||
|
||||
/// Run the fp32 gradient of a unary arithmetic node (LogGrad or NegGrad).
///
/// LogGrad writes ElementDiv(in_tensors[0], in_tensors[1]) and NegGrad writes
/// ElementNegative(in_tensors[0]) into out_tensors[0]. The operation is chosen by
/// node->primitive_->type_. The allocator argument is not used.
///
/// \return RET_OK on success, RET_PARAM_INVALID for any other kernel type,
///         RET_ERROR when the nnacl routine does not return NNACL_OK.
int DoArithmeticGradSelf(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node,
                         mindspore::lite::Allocator *allocator) {
  size_t elem_count = in_tensors[0]->ElementsNum();
  OpParameter *op_param = node->primitive_;
  float *dy = (float *)in_tensors[0]->data_;
  float *x = (float *)in_tensors[1]->data_;
  float *dx = (float *)out_tensors[0]->data_;

  const bool is_log_grad = op_param->type_ == KernelType::LogGrad;
  const bool is_neg_grad = op_param->type_ == KernelType::NegGrad;
  if (!is_log_grad && !is_neg_grad) {
    MS_LOG(ERROR) << "Unsupport kernel type: " << op_param->type_;
    return RET_PARAM_INVALID;
  }

  int status;
  if (is_log_grad) {
    status = ElementDiv(dy, x, dx, elem_count);
  } else {
    status = ElementNegative(dy, dx, elem_count);
  }

  if (status == NNACL_OK) {
    return RET_OK;
  }
  MS_LOG(ERROR) << "do arithmetic " << op_param->type_ << " fail!ret: " << status;
  return RET_ERROR;
}
|
|
@ -0,0 +1,26 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_GRAD_ARITHMETIC_SELF_GRAD_H_
|
||||
#define MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_GRAD_ARITHMETIC_SELF_GRAD_H_
|
||||
|
||||
#include "internal/include/model.h"
|
||||
#include "src/runtime/allocator.h"
|
||||
|
||||
/// Run the fp32 gradient of a unary arithmetic node (LogGrad or NegGrad): reads
/// in_tensors[0] and in_tensors[1], writes out_tensors[0]. The operation is selected
/// by node->primitive_->type_.
/// \return an error code from internal/include/errorcode.h (RET_OK on success).
int DoArithmeticGradSelf(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node,
                         mindspore::lite::Allocator *allocator);
|
||||
|
||||
#endif // MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_GRAD_ARITHMETIC_SELF_GRAD_H_
|
|
@ -17,6 +17,13 @@
|
|||
#include "internal/include/model.h"
|
||||
#include "internal/include/ms_tensor.h"
|
||||
#include "src/runtime/allocator.h"
|
||||
#include "internal/include/errorcode.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "internal/src/kernel/fp32/activation.h"
|
||||
#include "internal/src/kernel/fp32/arithmetic_self.h"
|
||||
#include "internal/src/kernel/fp32/matmul.h"
|
||||
#include "internal/src/kernel/fp32_grad/arithmetic_self_grad.h"
|
||||
#include "internal/src/kernel/fp32_grad/activation_grad.h"
|
||||
|
||||
static Context *g_Ctx;
|
||||
static Model *g_Model;
|
||||
|
@ -58,11 +65,56 @@ TensorPtrVector LiteSession::GetOutputs() const {
|
|||
|
||||
int LiteSession::RunGraph() {
|
||||
// invoke nnacl kernel
|
||||
return 0;
|
||||
NodePtrVector nodes = g_Model->nodes_;
|
||||
size_t nodes_size = nodes.size();
|
||||
for (size_t i = 0; i < nodes_size; ++i) {
|
||||
auto node = nodes[i];
|
||||
if (node->primitive_ == nullptr) {
|
||||
MS_LOG(ERROR) << "node's primitive is NULL!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
TensorPtrVector in_tensors;
|
||||
for (size_t j = 0; j < node->input_indices_.size(); ++j) {
|
||||
in_tensors.push_back(g_Model->all_tensors_[node->input_indices_[j]]);
|
||||
}
|
||||
TensorPtrVector out_tensors;
|
||||
for (size_t j = 0; j < node->output_indices_.size(); ++j) {
|
||||
out_tensors.push_back(g_Model->all_tensors_[node->output_indices_[j]]);
|
||||
}
|
||||
int type = node->primitive_->type_;
|
||||
int ret = RET_ERROR;
|
||||
switch (type) {
|
||||
case KernelType::MatMul:
|
||||
ret = DoMatMul(in_tensors, out_tensors, node, &allocator);
|
||||
break;
|
||||
case KernelType::Activation:
|
||||
ret = DoActivation(in_tensors, out_tensors, node, &allocator);
|
||||
break;
|
||||
case KernelType::Log:
|
||||
case KernelType::Neg:
|
||||
ret = DoArithmeticSelf(in_tensors, out_tensors, node, &allocator);
|
||||
break;
|
||||
case KernelType::LogGrad:
|
||||
case KernelType::NegGrad:
|
||||
ret = DoArithmeticGradSelf(in_tensors, out_tensors, node, &allocator);
|
||||
break;
|
||||
case KernelType::ActivationGrad:
|
||||
ret = DoActivationGrad(in_tensors, out_tensors, node, &allocator);
|
||||
break;
|
||||
default:
|
||||
MS_LOG(ERROR) << "Unsupport kernel type: " << type;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "run kernel fail!ret: " << ret;
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
// Placeholder: no lookup is performed, an empty list is always returned.
StringVector LiteSession::GetOutputTensorNames() const {
  StringVector no_names;
  return no_names;
}
|
||||
|
||||
// Placeholder: no lookup is performed, NULL is always returned.
MSTensor *LiteSession::GetOutputByTensorName(const String &tensor_name) const {
  MSTensor *not_found = NULL;
  return not_found;
}
|
||||
|
||||
// Placeholder: performs no resizing and always returns 0.
int LiteSession::Resize(const TensorPtrVector &inputs) {
  const int status = 0;
  return status;
}
|
||||
// Placeholder: ignores both the inputs and the requested dims, always returns 0.
int LiteSession::Resize(const TensorPtrVector &inputs, Int32VectorVector dims) {
  const int status = 0;
  return status;
}
|
||||
|
|
Loading…
Reference in New Issue