forked from mindspore-Ecosystem/mindspore
!22703 [MSLITE][DEVELOP] code review of op: lstm, layer_norm, deconv, etc; move elu op to activation
Merge pull request !22703 from yangruoqi713/master
This commit is contained in:
commit 94f010aee6
@@ -67,34 +67,41 @@ void LayerNormGammaAndBetaFp16(float16_t *dst, const float16_t *src, const float
 int LayerNormFp16(const float16_t *src_data, const float16_t *gamma_data, const float16_t *beta_data,
                   float16_t *dst_data, float16_t *out_mean, float16_t *out_deno, LayerNormParameter *param,
                   size_t task_id) {
-  if (src_data == NULL || dst_data == NULL || gamma_data == NULL || beta_data == NULL || out_mean == NULL ||
-      out_deno == NULL) {
+  if (src_data == NULL || dst_data == NULL || gamma_data == NULL || beta_data == NULL) {
     return NNACL_NULL_PTR;
   }
   NNACL_CHECK_ZERO_RETURN_ERR(param->params_inner_size_);
   NNACL_CHECK_ZERO_RETURN_ERR(param->params_outer_size_);
   int step = UP_DIV(param->norm_outer_size_, param->op_parameter_.thread_num_);
   int thread_end = MSMIN((task_id + 1) * step, param->norm_outer_size_);
   for (int i = task_id * step; i < thread_end; i++) {
     const float16_t *src_norm = src_data + i * param->norm_inner_size_;
     float16_t *dst_norm = dst_data + i * param->norm_inner_size_;
-    out_mean[i] = 0.0f;
-    out_deno[i] = 0.0f;
-    int ret = LayerNormMeanAndSquareFp16(src_norm, param->norm_inner_size_, &out_mean[i], &out_deno[i]);
+    float16_t cur_mean = 0.0f;
+    float16_t cur_deno = 0.0f;
+    int ret = LayerNormMeanAndSquareFp16(src_norm, param->norm_inner_size_, &cur_mean, &cur_deno);
     if (ret != NNACL_OK) {
       return NNACL_ERR;
     }
-    const float16_t deno = 1 / sqrtf(out_deno[i] - out_mean[i] * out_mean[i] + param->epsilon_);
+    if (out_mean != NULL) {
+      out_mean[i] = cur_mean;
+    }
+    if (out_deno != NULL) {
+      out_deno[i] = cur_deno;
+    }
+    const float16_t deno = 1 / sqrtf(cur_deno - cur_mean * cur_mean + param->epsilon_);
     if (param->norm_outer_size_ <= param->params_outer_size_) {
       for (int x = 0; x < param->norm_inner_size_ / param->params_inner_size_; x++) {
         const float16_t *src_param = src_norm + x * param->params_inner_size_;
         float16_t *dst_param = dst_norm + x * param->params_inner_size_;
-        LayerNormGammaAndBetaFp16(dst_param, src_param, gamma_data, beta_data, param->params_inner_size_, out_mean[i],
+        LayerNormGammaAndBetaFp16(dst_param, src_param, gamma_data, beta_data, param->params_inner_size_, cur_mean,
                                   deno);
       }
     } else {
       int x = i / param->params_outer_size_;
       const float16_t *gamma = gamma_data + x * param->norm_inner_size_;
       const float16_t *beta = beta_data + x * param->norm_inner_size_;
-      LayerNormGammaAndBetaFp16(dst_norm, src_norm, gamma, beta, param->norm_inner_size_, out_mean[i], deno);
+      LayerNormGammaAndBetaFp16(dst_norm, src_norm, gamma, beta, param->norm_inner_size_, cur_mean, deno);
     }
   }
   return NNACL_OK;
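For orientation: LayerNormMeanAndSquareFp16 leaves the mean in cur_mean and the mean of squares in cur_deno, so cur_deno - cur_mean * cur_mean is the variance. A minimal scalar sketch of the same computation, with hypothetical names and no claim to match the nnacl internals:

#include <math.h>

/* Scalar layer-norm reference over one normalized vector of inner_size
 * elements; gamma/beta are the affine parameters, epsilon guards the
 * denominator. Hypothetical helper, not part of nnacl. */
static void layer_norm_ref(const float *src, const float *gamma, const float *beta,
                           float *dst, int inner_size, float epsilon) {
  float mean = 0.0f, square = 0.0f;
  for (int j = 0; j < inner_size; ++j) {
    mean += src[j];
    square += src[j] * src[j];
  }
  mean /= (float)inner_size;    /* E[x] */
  square /= (float)inner_size;  /* E[x^2] */
  const float deno = 1.0f / sqrtf(square - mean * mean + epsilon);
  for (int j = 0; j < inner_size; ++j) {
    dst[j] = (src[j] - mean) * deno * gamma[j] + beta[j];
  }
}

Promoting the accumulators to locals is what lets the revised code treat out_mean and out_deno as optional outputs instead of rejecting NULL up front.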
@@ -284,3 +284,22 @@ int Softplus(const float *src, int length, float *dst) {
   }
   return NNACL_OK;
 }
+
+int Elu(const float *src, int length, float *dst, float alpha) {
+  int i = 0;
+#if defined(ENABLE_ARM)
+  MS_FLOAT32X4 one = MS_MOVQ_F32(1.0f);
+  for (; i <= length - 4; i += 4) {
+    MS_FLOAT32X4 src_tmp = MS_LDQ_F32(src + i);
+    MS_FLOAT32X4 exp_tmp = VexpFp32(src_tmp);  // exp(x)
+    exp_tmp = MS_SUBQ_F32(exp_tmp, one);       // exp(x) - 1
+    MS_FLOAT32X4 elu_tmp = MS_MULQ_N_F32(exp_tmp, alpha);
+    MS_UINT32X4 mask = MS_CMPGTQ_F32(src_tmp, MS_MOVQ_F32(0.0f));
+    MS_STQ_F32(dst + i, MS_BLENDQ_F32(elu_tmp, src_tmp, mask));
+  }
+#endif
+  for (; i < length; ++i) {
+    dst[i] = src[i] > 0 ? src[i] : (expm1(src[i]) * alpha);
+  }
+  return NNACL_OK;
+}
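A note on the two paths above: the scalar tail uses expm1, which computes exp(x) - 1 directly and preserves precision for x near zero, while the NEON path approximates the same quantity as VexpFp32(x) minus one. A hedged scalar reference for a single element, mirroring the tail loop:

#include <math.h>

/* ELU reference: identity for x > 0, alpha * (e^x - 1) otherwise.
 * Hypothetical helper name, not part of nnacl. */
static inline float elu_ref(float x, float alpha) {
  return x > 0.0f ? x : alpha * expm1f(x);
}

Trading expm1 for exp - 1 in the vector path is a precision-for-throughput choice: near zero the subtraction cancels leading bits, but for typical activation magnitudes the error is small.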
@@ -42,6 +42,7 @@ int HSwish(const float *src, int length, float *dst);
 int HardTanh(const float *src, int length, float *dst, float min_val, float max_val);
 int Gelu(const float *src, int length, float *dst, bool approximate);
 int Softplus(const float *src, int length, float *dst);
+int Elu(const float *src, int length, float *dst, float alpha);

 float TanhOpt(float src);
 #ifdef __cplusplus
@@ -1,30 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "nnacl/fp32/elu_fp32.h"
-#include <math.h>
-#include "nnacl/errorcode.h"
-
-void Calculate_Data(const float *input_data, float *output_data, int num, const EluParameter *parameter) {
-  output_data[num] = input_data[num] < 0 ? parameter->alpha_ * expm1(input_data[num]) : input_data[num];
-}
-
-int Elu(const float *input_data, float *output_data, const EluParameter *parameter, int task_id) {
-  for (size_t i = task_id; i < parameter->in_size_; i += parameter->op_parameter_.thread_num_) {
-    Calculate_Data(input_data, output_data, i, parameter);
-  }
-  return NNACL_OK;
-}
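The deleted kernel above partitioned work by striding single elements across threads, which scatters each task's accesses and defeats the four-wide NEON loop; the replacement in activation_fp32 hands each task one contiguous block instead (see the stride * task_id offsets in the activation kernel further down). A hedged sketch of the old split, with hypothetical names:

#include <math.h>

/* Old style: task task_id touches elements task_id, task_id + thread_num, ...
 * Functionally identical to the deleted Calculate_Data loop, but
 * cache-unfriendly and not vectorizable. */
static void elu_interleaved(const float *src, float *dst, int size, float alpha,
                            int task_id, int thread_num) {
  for (int i = task_id; i < size; i += thread_num) {
    dst[i] = src[i] < 0.0f ? alpha * expm1f(src[i]) : src[i];
  }
}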
@@ -1,39 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_NNACL_FP32_ELU_H_
-#define MINDSPORE_NNACL_FP32_ELU_H_
-
-#include "nnacl/op_base.h"
-
-typedef struct EluParameter {
-  OpParameter op_parameter_;
-  // primitive parameter
-  float alpha_;
-
-  // shape correlative
-  int in_size_;
-} EluParameter;
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-int Elu(const float *input_data, float *output_data, const EluParameter *parameter, int task_id);
-#ifdef __cplusplus
-}
-#endif
-
-#endif  // MINDSPORE_NNACL_FP32_ELU_H_
@@ -69,8 +69,7 @@ void LayerNormGammaAndBeta(float *dst, const float *src, const float *gamma_data

 int LayerNorm(const float *src_data, const float *gamma_data, const float *beta_data, float *dst_data, float *out_mean,
               float *out_deno, LayerNormParameter *param, size_t task_id) {
-  if (src_data == NULL || dst_data == NULL || gamma_data == NULL || beta_data == NULL || out_mean == NULL ||
-      out_deno == NULL) {
+  if (src_data == NULL || dst_data == NULL || gamma_data == NULL || beta_data == NULL) {
     return NNACL_NULL_PTR;
   }
   NNACL_CHECK_ZERO_RETURN_ERR(param->params_inner_size_);

@@ -80,25 +79,30 @@ int LayerNorm(const float *src_data, const float *gamma_data, const float *beta_
   for (int i = task_id * step; i < thread_end; i++) {
     const float *src_norm = src_data + i * param->norm_inner_size_;
     float *dst_norm = dst_data + i * param->norm_inner_size_;
-    out_mean[i] = 0.0f;
-    out_deno[i] = 0.0f;
-    int ret = LayerNormMeanAndSquare(src_norm, param->norm_inner_size_, &out_mean[i], &out_deno[i]);
+    float cur_mean = 0.0f;
+    float cur_deno = 0.0f;
+    int ret = LayerNormMeanAndSquare(src_norm, param->norm_inner_size_, &cur_mean, &cur_deno);
     if (ret != NNACL_OK) {
       return NNACL_ERR;
     }
-    const float deno = 1 / sqrtf(out_deno[i] - out_mean[i] * out_mean[i] + param->epsilon_);
+    if (out_mean != NULL) {
+      out_mean[i] = cur_mean;
+    }
+    if (out_deno != NULL) {
+      out_deno[i] = cur_deno;
+    }
+    const float deno = 1 / sqrtf(cur_deno - cur_mean * cur_mean + param->epsilon_);
     if (param->norm_outer_size_ <= param->params_outer_size_) {
       for (int x = 0; x < param->norm_inner_size_ / param->params_inner_size_; x++) {
         const float *src_param = src_norm + x * param->params_inner_size_;
         float *dst_param = dst_norm + x * param->params_inner_size_;
-        LayerNormGammaAndBeta(dst_param, src_param, gamma_data, beta_data, param->params_inner_size_, out_mean[i],
-                              deno);
+        LayerNormGammaAndBeta(dst_param, src_param, gamma_data, beta_data, param->params_inner_size_, cur_mean, deno);
       }
     } else {
       int x = i / param->params_outer_size_;
       const float *gamma = gamma_data + x * param->norm_inner_size_;
       const float *beta = beta_data + x * param->norm_inner_size_;
-      LayerNormGammaAndBeta(dst_norm, src_norm, gamma, beta, param->norm_inner_size_, out_mean[i], deno);
+      LayerNormGammaAndBeta(dst_norm, src_norm, gamma, beta, param->norm_inner_size_, cur_mean, deno);
     }
   }
   return NNACL_OK;
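Both layer-norm kernels split norm_outer_size_ rows across tasks with the same UP_DIV/MSMIN idiom. A hedged, self-contained sketch of that partition; UP_DIV and MSMIN are restated locally here with what I take to be their nnacl semantics:

#define UP_DIV(x, y) (((x) + (y) - 1) / (y))
#define MSMIN(a, b) ((a) < (b) ? (a) : (b))

/* Task task_id handles rows [begin, end); returns the row count,
 * which may be 0 for trailing tasks when rows < thread_num * step. */
static int rows_for_task(int norm_outer_size, int thread_num, int task_id,
                         int *begin, int *end) {
  int step = UP_DIV(norm_outer_size, thread_num);  /* ceiling division */
  *begin = task_id * step;
  *end = MSMIN((task_id + 1) * step, norm_outer_size);
  return *end - *begin;
}

With the relaxed NULL contract above, a caller that only needs the normalized output can now pass NULL for out_mean/out_deno, which is exactly what the single-output path of the CPU kernels below does.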
@@ -30,14 +30,23 @@ int LayerNormInferShape(const TensorC *const *inputs, size_t inputs_size, Tensor
   const TensorC *input = inputs[0];
   TensorC *output = outputs[0];
   SetDataTypeFormat(output, input);
-
-  LayerNormParameter *param = (LayerNormParameter *)parameter;
   if (!InferFlag(inputs, inputs_size)) {
     return NNACL_INFER_INVALID;
   }
-  if (input->shape_size_ > MAX_SHAPE_SIZE) {
+
+  LayerNormParameter *param = (LayerNormParameter *)parameter;
+  NNACL_CHECK_NULL_RETURN_ERR(param);
+  if (input->shape_size_ > COMM_SHAPE_SIZE) {
     return NNACL_INPUT_TENSOR_ERROR;
   }
   if (param->begin_params_axis_ < (-1 * (int)(input->shape_size_)) ||
       param->begin_params_axis_ >= (int)(input->shape_size_)) {
     return NNACL_PARAM_INVALID;
   }
+  if (param->begin_norm_axis_ < (-1 * (int)(input->shape_size_)) ||
+      param->begin_norm_axis_ >= (int)(input->shape_size_)) {
+    return NNACL_PARAM_INVALID;
+  }
   param->begin_norm_axis_ =
     param->begin_norm_axis_ < 0 ? param->begin_norm_axis_ + ((int)(input->shape_size_)) : param->begin_norm_axis_;
   SetShapeTensor(output, input);
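This check and the log-softmax one just below follow the same rule: an axis is valid in [-rank, rank - 1], and a negative axis is folded by adding the rank. A hedged helper capturing the rule (hypothetical name, not nnacl API):

/* Returns the normalized axis, or -1 for out-of-range input; mirrors
 * the begin_norm_axis_ handling above. Callers would map -1 to
 * NNACL_PARAM_INVALID. */
static int normalize_axis(int axis, int rank) {
  if (axis < -rank || axis >= rank) {
    return -1;
  }
  return axis < 0 ? axis + rank : axis;
}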
@@ -26,9 +26,8 @@ int LogSoftmaxInferShape(const TensorC *const *inputs, size_t inputs_size, Tenso
   const TensorC *input = inputs[0];
   TensorC *output = outputs[0];
-  output->data_type_ = input->data_type_;
-  output->format_ = input->format_;
+  SetDataTypeFormat(output, input);
   if (!InferFlag(inputs, inputs_size)) {
     return NNACL_INFER_INVALID;
   }

@@ -36,6 +35,11 @@ int LogSoftmaxInferShape(const TensorC *const *inputs, size_t inputs_size, Tenso
     return NNACL_ERR;
   }
   SetShapeTensor(output, input);
+  SoftmaxParameter *param = (SoftmaxParameter *)parameter;
+  NNACL_CHECK_NULL_RETURN_ERR(param);
+  if (param->axis_ < (-1 * (int)(input->shape_size_)) || param->axis_ >= (int)(input->shape_size_)) {
+    return NNACL_PARAM_INVALID;
+  }
   return NNACL_OK;
 }
@@ -32,6 +32,7 @@ int LstmInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **o
   }

   LstmParameter *param = (LstmParameter *)parameter;
+  NNACL_CHECK_NULL_RETURN_ERR(param);
   if (!InferFlag(inputs, inputs_size)) {
     return NNACL_INFER_INVALID;
   }
@@ -37,6 +37,8 @@ int LayerNormInt8(const int8_t *src_data, const float *gamma_data, const float *
   if (src_data == NULL || dst_data == NULL || gamma_data == NULL || beta_data == NULL) {
     return NNACL_NULL_PTR;
   }
+  NNACL_CHECK_ZERO_RETURN_ERR(param->params_inner_size_);
+  NNACL_CHECK_ZERO_RETURN_ERR(param->params_outer_size_);

   int step = UP_DIV(param->norm_outer_size_, param->op_parameter_.thread_num_);
   int thread_end = MSMIN((task_id + 1) * step, param->norm_outer_size_);
@@ -18,11 +18,6 @@
 #include "nnacl/op_base.h"
 #include "nnacl/int8/quantize.h"
-#define GAMMA_INDEX 1
-#define BETA_INDEX 2
-#define MEAN_INDEX 1
-#define VARIANCE_INDEX 2
-
 enum ElementwiseMode { ELEMENTWISE_NOT = 0, ELEMENTWISE_PER_CHANNEL = 1, ELEMENTWISE_PER_NUM = 2 };
 typedef struct LayerNormParameter {
   // Primitive parameter
@@ -46,14 +46,6 @@
   }                                \
   } while (0)

-#define MS_CHECK_PTR_IF_NULL(ptr)                                 \
-  do {                                                            \
-    if ((ptr) == nullptr) {                                       \
-      MS_LOG(ERROR) << ": The pointer[" << #ptr << "] is null.";  \
-      return;                                                     \
-    }                                                             \
-  } while (0)
-
 #define MS_CHECK_RET_CODE(code, msg) \
   do {                               \
     if ((code) != RET_OK) {          \
@@ -100,7 +100,7 @@ void NNaclInt8Serializer::CodeStruct(const std::string &name, const PoolingParam
   MS_CHECK_PTR_IF_NULL(in_quant_args);
   MS_CHECK_PTR_IF_NULL(out_quant_args);

-  code << "static QuantArg " << in_quant_name << " = " << *out_quant_args << ";\n";
+  code << "static QuantArg " << in_quant_name << " = " << *in_quant_args << ";\n";
   code << "static QuantArg " << out_quant_name << " = " << *out_quant_args << ";\n";

   code << "static QuantArg *" << quant_name << "[2] = {"
@@ -1,45 +0,0 @@
-/**
- * Copyright 2019-2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "nnacl/fp32/elu_fp32.h"
-#include "src/ops/populate/populate_register.h"
-using mindspore::schema::PrimitiveType_Elu;
-
-namespace mindspore {
-namespace lite {
-OpParameter *PopulateEluParameter(const void *prim) {
-  auto primitive = static_cast<const schema::Primitive *>(prim);
-  MS_ASSERT(primitive != nullptr);
-  auto value = primitive->value_as_Elu();
-  if (value == nullptr) {
-    MS_LOG(ERROR) << "value is nullptr";
-    return nullptr;
-  }
-
-  auto *param = reinterpret_cast<EluParameter *>(malloc(sizeof(EluParameter)));
-  if (param == nullptr) {
-    MS_LOG(ERROR) << "malloc EluParameter failed.";
-    return nullptr;
-  }
-  memset(param, 0, sizeof(EluParameter));
-
-  param->op_parameter_.type_ = primitive->value_type();
-  param->alpha_ = value->alpha();
-  return reinterpret_cast<OpParameter *>(param);
-}
-
-REG_POPULATE(PrimitiveType_Elu, PopulateEluParameter, SCHEMA_CUR)
-} // namespace lite
-} // namespace mindspore
@@ -1,47 +0,0 @@
-/**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "schema/model_v0_generated.h"
-#include "src/ops/populate/populate_register.h"
-#include "nnacl/fp32/elu_fp32.h"
-
-namespace mindspore {
-namespace lite {
-namespace {
-OpParameter *PopulateEluParameter(const void *prim) {
-  auto *primitive = static_cast<const schema::v0::Primitive *>(prim);
-  MS_ASSERT(primitive != nullptr);
-  auto elu_prim = primitive->value_as_Elu();
-  if (elu_prim == nullptr) {
-    MS_LOG(ERROR) << "elu_prim is nullptr";
-    return nullptr;
-  }
-  auto *elu_parameter = reinterpret_cast<EluParameter *>(malloc(sizeof(EluParameter)));
-  if (elu_parameter == nullptr) {
-    MS_LOG(ERROR) << "malloc EluParameter failed.";
-    return nullptr;
-  }
-  memset(elu_parameter, 0, sizeof(EluParameter));
-  elu_parameter->op_parameter_.type_ = schema::PrimitiveType_Elu;
-
-  elu_parameter->alpha_ = elu_prim->alpha();
-  return reinterpret_cast<OpParameter *>(elu_parameter);
-}
-} // namespace
-
-Registry g_eluV0ParameterRegistry(schema::v0::PrimitiveType_Elu, PopulateEluParameter, SCHEMA_V0);
-} // namespace lite
-} // namespace mindspore
@@ -40,6 +40,7 @@ int SoftmaxBaseCPUKernel::Init() {

 int SoftmaxBaseCPUKernel::ReSize() {
   auto input_tensor = in_tensors_.front();
+  CHECK_NULL_RETURN(input_tensor);
   auto in_shape = input_tensor->shape();
   auto in_dims = in_shape.size();
   int ele_size = 1;
@@ -28,6 +28,9 @@ using mindspore::schema::PrimitiveType_LayerNormFusion;

 namespace mindspore::kernel {
 int LayerNormFp16CPUKernel::Init() {
+  CHECK_LESS_RETURN(in_tensors_.size(), 3);
+  CHECK_LESS_RETURN(out_tensors_.size(), 1);
+  CHECK_NULL_RETURN(param_);
   if (!InferShapeDone()) {
     return RET_OK;
   }

@@ -35,7 +38,9 @@ int LayerNormFp16CPUKernel::Init() {
 }

 int LayerNormFp16CPUKernel::ReSize() {
-  auto shape = in_tensors_.front()->shape();
+  auto input = in_tensors_.front();
+  CHECK_NULL_RETURN(input);
+  auto shape = input->shape();
   param_->begin_norm_axis_ =
     param_->begin_norm_axis_ > 0 ? param_->begin_norm_axis_ : param_->begin_norm_axis_ + shape.size();
   param_->begin_params_axis_ =

@@ -62,7 +67,7 @@ int LayerNormFp16CPUKernel::ReSize() {
 }

 int LayerNormFp16CPUKernel::DoLayerNormFp16(int thread_id) {
-  int ret = LayerNormFp16(src_data_, gamma_data_, beta_data_, dst_data_, mean_data_, var_data_, param_, thread_id);
+  auto ret = LayerNormFp16(src_data_, gamma_data_, beta_data_, dst_data_, mean_data_, var_data_, param_, thread_id);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "DoLayerNorm error error_code[" << ret << "]";
     return ret;

@@ -72,6 +77,7 @@ int LayerNormFp16CPUKernel::DoLayerNormFp16(int thread_id) {

 int LayerNormFp16Run(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
   auto kernel = reinterpret_cast<LayerNormFp16CPUKernel *>(cdata);
+  CHECK_NULL_RETURN(kernel);
   auto ret = kernel->DoLayerNormFp16(task_id);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "LayerNormFp16Run error task_id[" << task_id << "] error_code[" << ret << "]";

@@ -81,25 +87,25 @@ int LayerNormFp16Run(void *cdata, int task_id, float lhs_scale, float rhs_scale)
 }

 int LayerNormFp16CPUKernel::Run() {
-  int ret = RET_OK;
   src_data_ = reinterpret_cast<float16_t *>(in_tensors_.at(0)->data_c());
-  gamma_data_ = reinterpret_cast<float16_t *>(in_tensors_.at(GAMMA_INDEX)->data_c());
-  beta_data_ = reinterpret_cast<float16_t *>(in_tensors_.at(BETA_INDEX)->data_c());
+  CHECK_NULL_RETURN(src_data_);
+  gamma_data_ = reinterpret_cast<float16_t *>(in_tensors_.at(1)->data_c());
+  CHECK_NULL_RETURN(gamma_data_);
+  beta_data_ = reinterpret_cast<float16_t *>(in_tensors_.at(2)->data_c());
+  CHECK_NULL_RETURN(beta_data_);
   dst_data_ = reinterpret_cast<float16_t *>(out_tensors_.at(0)->data_c());
-  if (out_tensors_.size() == kInputSize2) {
-    mean_data_ = reinterpret_cast<float16_t *>(out_tensors_.at(MEAN_INDEX)->data_c());
-    var_data_ = reinterpret_cast<float16_t *>(out_tensors_.at(VARIANCE_INDEX)->data_c());
-  } else {
-    mean_data_ =
-      reinterpret_cast<float16_t *>(ms_context_->allocator->Malloc(param_->norm_outer_size_ * sizeof(float16_t)));
-    var_data_ =
-      reinterpret_cast<float16_t *>(ms_context_->allocator->Malloc(param_->norm_outer_size_ * sizeof(float16_t)));
-  }
-  ret = ParallelLaunch(this->ms_context_, LayerNormFp16Run, this, op_parameter_->thread_num_);
-  if (out_tensors_.size() != kInputSize2) {
-    ms_context_->allocator->Free(mean_data_);
-    ms_context_->allocator->Free(var_data_);
+  CHECK_NULL_RETURN(dst_data_);
+
+  if (out_tensors_.size() == 3) {
+    mean_data_ = reinterpret_cast<float16_t *>(out_tensors_.at(1)->data_c());
+    CHECK_NULL_RETURN(mean_data_);
+    var_data_ = reinterpret_cast<float16_t *>(out_tensors_.at(2)->data_c());
+    CHECK_NULL_RETURN(var_data_);
+  } else if (out_tensors_.size() != 1) {
+    MS_LOG(ERROR) << "LayerNorm should have 1 or 3 output tensors";
+    return RET_ERROR;
   }
+  auto ret = ParallelLaunch(this->ms_context_, LayerNormFp16Run, this, op_parameter_->thread_num_);
   return ret;
 }
@@ -31,6 +31,13 @@ using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_LogSoftmax;

 namespace mindspore::kernel {
+LogSoftmaxFp16CPUKernel::~LogSoftmaxFp16CPUKernel() {
+  if (tmp_data_ != nullptr) {
+    free(tmp_data_);
+    tmp_data_ = nullptr;
+  }
+}
+
 int LogSoftmaxFp16CPUKernel::Init() {
   auto ret = SoftmaxBaseCPUKernel::Init();
   if (ret != RET_OK) {

@@ -73,19 +80,23 @@ int LogSoftmaxFp16CPUKernel::ReSize() {
 }

 int LogSoftmaxFp16CPUKernel::DoLogSoftmaxLastAxis(int task_id) {
+  MS_CHECK_FALSE(op_parameter_->thread_num_ == 0, RET_ERROR);
   int unit = UP_DIV(out_plane_size_, op_parameter_->thread_num_);
   int begin = task_id * unit;
   int end = MSMIN(begin + unit, out_plane_size_);
   int channel = softmax_param_->input_shape_[softmax_param_->axis_];
   int offset = begin * channel;
-  auto input_ptr = reinterpret_cast<float16_t *>(in_tensors_.at(kInputIndex)->MutableData());
-  auto output_ptr = reinterpret_cast<float16_t *>(out_tensors_.at(kOutputIndex)->MutableData());
+  auto input_ptr = reinterpret_cast<float16_t *>(in_tensors_.at(kInputIndex)->data_c());
+  CHECK_NULL_RETURN(input_ptr);
+  auto output_ptr = reinterpret_cast<float16_t *>(out_tensors_.at(kOutputIndex)->data_c());
+  CHECK_NULL_RETURN(output_ptr);
   LogSoftmaxLastAxisFp16(input_ptr + offset, output_ptr + offset, tmp_data_ + offset, end - begin, channel);
   return RET_OK;
 }

 int LogSoftmaxLastAxisFp16Run(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
   auto kernel = reinterpret_cast<LogSoftmaxFp16CPUKernel *>(cdata);
+  CHECK_NULL_RETURN(kernel);
   auto ret = kernel->DoLogSoftmaxLastAxis(task_id);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "DoLogSoftmaxLastAxisFp16 error task_id: " << task_id << ", ret: " << ret;

@@ -101,15 +112,11 @@ int LogSoftmaxFp16CPUKernel::Run() {
     }
     return ret;
   } else {
-    auto input_tensor = in_tensors_.at(0);
-    MS_ASSERT(input_tensor);
-    auto output_tensor = out_tensors_.at(0);
-    MS_ASSERT(output_tensor);
-    input_fp16_ = reinterpret_cast<float16_t *>(input_tensor->data_c());
-    MS_ASSERT(input_fp16_);
-    output_fp16_ = reinterpret_cast<float16_t *>(output_tensor->data_c());
-    MS_ASSERT(output_fp16_);
-    MS_ASSERT(tmp_data_);
+    input_fp16_ = reinterpret_cast<float16_t *>(in_tensors_.at(0)->data_c());
+    CHECK_NULL_RETURN(input_fp16_);
+    output_fp16_ = reinterpret_cast<float16_t *>(out_tensors_.at(0)->data_c());
+    CHECK_NULL_RETURN(output_fp16_);
+    CHECK_NULL_RETURN(tmp_data_);
     LogSoftmaxFp16(input_fp16_, output_fp16_, tmp_data_, softmax_param_);
   }
   return RET_OK;

@@ -28,11 +28,7 @@ class LogSoftmaxFp16CPUKernel : public SoftmaxBaseCPUKernel {
   LogSoftmaxFp16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                           const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
       : SoftmaxBaseCPUKernel(parameter, inputs, outputs, ctx), tmp_data_(nullptr) {}
-  ~LogSoftmaxFp16CPUKernel() override {
-    if (tmp_data_ != nullptr) {
-      free(tmp_data_);
-    }
-  }
+  ~LogSoftmaxFp16CPUKernel() override;

   int Init() override;
   int ReSize() override;
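For reference, the last-axis routine these kernels parallelize computes the numerically stable form log_softmax(x_i) = (x_i - max) - log(sum_j exp(x_j - max)), with tmp_data_ as scratch for the shifted exponentials. A hedged scalar sketch over one channel vector (hypothetical helper, not the fp16 nnacl routine):

#include <math.h>

/* Stable log-softmax over channel elements: subtract the max before
 * exponentiating so the sum cannot overflow. */
static void log_softmax_ref(const float *src, float *dst, int channel) {
  float max_v = src[0];
  for (int j = 1; j < channel; ++j) {
    if (src[j] > max_v) max_v = src[j];
  }
  float sum = 0.0f;
  for (int j = 0; j < channel; ++j) {
    sum += expf(src[j] - max_v);
  }
  const float log_sum = logf(sum);
  for (int j = 0; j < channel; ++j) {
    dst[j] = src[j] - max_v - log_sum;
  }
}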
@@ -66,14 +66,12 @@ void LstmFp16CPUKernel::FreeRunBuffer() {

 int LstmFp16CPUKernel::InitParam() {
   auto input = in_tensors_.front();
-  MS_ASSERT(input != nullptr);
   std::vector<int> in_shape = input->shape();
   lstm_param_->seq_len_ = in_shape.at(0);
   lstm_param_->batch_ = in_shape.at(1);
   lstm_param_->input_size_ = in_shape.at(2);

   auto weight_i = in_tensors_.at(1);
-  MS_ASSERT(weight_i != nullptr);
   std::vector<int> w_shape = weight_i->shape();
   lstm_param_->hidden_size_ = w_shape.at(1) / gate_num;

@@ -95,8 +93,8 @@ int LstmFp16CPUKernel::InitInputWeightBias() {
   // weight -- row: hidden_size; col: input_size, need transpose
   // result -- row: seq_len * batch; col: hidden_size
   auto weight_i = in_tensors_.at(1);
-  MS_ASSERT(weight_i != nullptr);
-  MS_ASSERT(weight_i->data_c() != nullptr);
+  auto weight_i_data = weight_i->data_c();
+  CHECK_NULL_RETURN(weight_i_data);
   weight_i_ptr_ = reinterpret_cast<float16_t *>(
     malloc(weight_batch_ * lstm_param_->input_col_align_ * lstm_param_->input_size_ * sizeof(float16_t)));
   if (weight_i_ptr_ == nullptr) {

@@ -104,10 +102,10 @@ int LstmFp16CPUKernel::InitInputWeightBias() {
     return RET_ERROR;
   }
   if (weight_i->data_type() == kNumberTypeFloat32) {
-    PackLstmWeightFp32ToFp16(weight_i_ptr_, reinterpret_cast<float *>(weight_i->data_c()), weight_batch_,
+    PackLstmWeightFp32ToFp16(weight_i_ptr_, reinterpret_cast<float *>(weight_i_data), weight_batch_,
                              lstm_param_->input_size_, lstm_param_->hidden_size_, lstm_param_->input_col_align_);
   } else if (weight_i->data_type() == kNumberTypeFloat16) {
-    PackLstmWeightFp16(weight_i_ptr_, reinterpret_cast<float16_t *>(weight_i->data_c()), weight_batch_,
+    PackLstmWeightFp16(weight_i_ptr_, reinterpret_cast<float16_t *>(weight_i_data), weight_batch_,
                        lstm_param_->input_size_, lstm_param_->hidden_size_, lstm_param_->input_col_align_);
   } else {
     MS_LOG(ERROR) << "Unsupported data type of weight_i tensor for lstm.";

@@ -116,8 +114,8 @@ int LstmFp16CPUKernel::InitInputWeightBias() {

   // input bias
   auto bias = in_tensors_.at(3);
-  MS_ASSERT(bias != nullptr);
-  MS_ASSERT(bias->data_c() != nullptr);
+  auto bias_data = bias->data_c();
+  CHECK_NULL_RETURN(bias_data);
   input_bias_ =
     reinterpret_cast<float16_t *>(malloc(weight_batch_ * lstm_param_->input_col_align_ * sizeof(float16_t)));
   if (input_bias_ == nullptr) {

@@ -126,11 +124,11 @@ int LstmFp16CPUKernel::InitInputWeightBias() {
   }
   memset(input_bias_, 0, weight_batch_ * lstm_param_->input_col_align_ * sizeof(float16_t));
   if (bias->data_type() == kNumberTypeFloat32) {
-    PackLstmBiasFp32ToFp16(input_bias_, reinterpret_cast<float *>(bias->data_c()), weight_batch_,
-                           lstm_param_->hidden_size_, lstm_param_->input_col_align_, lstm_param_->bidirectional_);
+    PackLstmBiasFp32ToFp16(input_bias_, reinterpret_cast<float *>(bias_data), weight_batch_, lstm_param_->hidden_size_,
+                           lstm_param_->input_col_align_, lstm_param_->bidirectional_);
   } else if (bias->data_type() == kNumberTypeFloat16) {
-    PackLstmBiasFp16(input_bias_, reinterpret_cast<float16_t *>(bias->data_c()), weight_batch_,
-                     lstm_param_->hidden_size_, lstm_param_->input_col_align_, lstm_param_->bidirectional_);
+    PackLstmBiasFp16(input_bias_, reinterpret_cast<float16_t *>(bias_data), weight_batch_, lstm_param_->hidden_size_,
+                     lstm_param_->input_col_align_, lstm_param_->bidirectional_);
   } else {
     MS_LOG(ERROR) << "Unsupported data type of bias tensor for lstm.";
     return RET_ERROR;

@@ -144,8 +142,8 @@ int LstmFp16CPUKernel::InitStateWeightBias() {
   // weight -- row: hidden_size; col: hidden_size, need transpose
   // result -- row: batch; col: hidden_size
   auto weight_h = in_tensors_.at(2);
-  MS_ASSERT(weight_h != nullptr);
-  MS_ASSERT(weight_h->data_c() != nullptr);
+  auto weight_h_data = weight_h->data_c();
+  CHECK_NULL_RETURN(weight_h_data);
   weight_h_ptr_ = reinterpret_cast<float16_t *>(
     malloc(weight_batch_ * lstm_param_->state_col_align_ * lstm_param_->hidden_size_ * sizeof(float16_t)));
   if (weight_h_ptr_ == nullptr) {

@@ -155,10 +153,10 @@ int LstmFp16CPUKernel::InitStateWeightBias() {

   if (!is_vec_) {
     if (weight_h->data_type() == kNumberTypeFloat32) {
-      PackLstmWeightFp32ToFp16(weight_h_ptr_, reinterpret_cast<float *>(weight_h->data_c()), weight_batch_,
+      PackLstmWeightFp32ToFp16(weight_h_ptr_, reinterpret_cast<float *>(weight_h_data), weight_batch_,
                                lstm_param_->hidden_size_, lstm_param_->hidden_size_, lstm_param_->state_col_align_);
     } else if (weight_h->data_type() == kNumberTypeFloat16) {
-      PackLstmWeightFp16(weight_h_ptr_, reinterpret_cast<float16_t *>(weight_h->data_c()), weight_batch_,
+      PackLstmWeightFp16(weight_h_ptr_, reinterpret_cast<float16_t *>(weight_h_data), weight_batch_,
                          lstm_param_->hidden_size_, lstm_param_->hidden_size_, lstm_param_->state_col_align_);
     } else {
       MS_LOG(ERROR) << "Unsupported data type of weight_h tensor for lstm.";

@@ -166,9 +164,9 @@ int LstmFp16CPUKernel::InitStateWeightBias() {
     }
   } else {
     if (weight_h->data_type() == kNumberTypeFloat32) {
-      Float32ToFloat16(reinterpret_cast<float *>(weight_h->data_c()), weight_h_ptr_, weight_h->ElementsNum());
+      Float32ToFloat16(reinterpret_cast<float *>(weight_h_data), weight_h_ptr_, weight_h->ElementsNum());
     } else if (weight_h->data_type() == kNumberTypeFloat16) {
-      memcpy(weight_h_ptr_, reinterpret_cast<float16_t *>(weight_h->data_c()), weight_h->Size());
+      memcpy(weight_h_ptr_, reinterpret_cast<float16_t *>(weight_h_data), weight_h->Size());
     } else {
       MS_LOG(ERROR) << "Unsupported data type of weight_h tensor for lstm.";
       return RET_ERROR;

@@ -177,8 +175,8 @@ int LstmFp16CPUKernel::InitStateWeightBias() {

   // state bias
   auto bias = in_tensors_.at(3);
-  MS_ASSERT(bias != nullptr);
-  MS_ASSERT(bias->data_c() != nullptr);
+  auto bias_data = bias->data_c();
+  CHECK_NULL_RETURN(bias_data);
   state_bias_ =
     reinterpret_cast<float16_t *>(malloc(weight_batch_ * lstm_param_->state_col_align_ * sizeof(float16_t)));
   if (state_bias_ == nullptr) {

@@ -187,11 +185,11 @@ int LstmFp16CPUKernel::InitStateWeightBias() {
   }
   memset(state_bias_, 0, weight_batch_ * lstm_param_->state_col_align_ * sizeof(float16_t));
   if (bias->data_type() == kNumberTypeFloat32) {
-    auto state_bias_data = reinterpret_cast<float *>(bias->data_c()) + gate_num * lstm_param_->hidden_size_;
+    auto state_bias_data = reinterpret_cast<float *>(bias_data) + gate_num * lstm_param_->hidden_size_;
     PackLstmBiasFp32ToFp16(state_bias_, state_bias_data, weight_batch_, lstm_param_->hidden_size_,
                            lstm_param_->state_col_align_, lstm_param_->bidirectional_);
   } else if (bias->data_type() == kNumberTypeFloat16) {
-    auto state_bias_data = reinterpret_cast<float16_t *>(bias->data_c()) + gate_num * lstm_param_->hidden_size_;
+    auto state_bias_data = reinterpret_cast<float16_t *>(bias_data) + gate_num * lstm_param_->hidden_size_;
     PackLstmBiasFp16(state_bias_, state_bias_data, weight_batch_, lstm_param_->hidden_size_,
                      lstm_param_->state_col_align_, lstm_param_->bidirectional_);
   } else {

@@ -203,7 +201,14 @@ int LstmFp16CPUKernel::InitStateWeightBias() {

 int LstmFp16CPUKernel::Init() {
   CHECK_LESS_RETURN(in_tensors_.size(), 6);
+  for (size_t i = 0; i < in_tensors_.size(); i++) {
+    CHECK_NULL_RETURN(in_tensors_.at(i));
+  }
   CHECK_LESS_RETURN(out_tensors_.size(), 3);
+  for (size_t i = 0; i < out_tensors_.size(); i++) {
+    CHECK_NULL_RETURN(out_tensors_.at(i));
+  }
+  CHECK_NULL_RETURN(lstm_param_);
   if (!InferShapeDone()) {
     return RET_OK;
   }

@@ -288,26 +293,23 @@ int LstmFp16CPUKernel::MallocRunBuffer() {
 }

 int LstmFp16CPUKernel::Run() {
-  auto input = in_tensors_.at(kInputIndex);
-  MS_ASSERT(input != nullptr);
-  auto hidden_state = in_tensors_.at(4);
-  MS_ASSERT(hidden_state != nullptr);
-  MS_ASSERT(hidden_state->data_c() != nullptr);
-  auto cell_state = in_tensors_.at(5);
-  MS_ASSERT(cell_state != nullptr);
-  MS_ASSERT(cell_state->data_c() != nullptr);
-  auto output = out_tensors_.at(0);
-  MS_ASSERT(output != nullptr);
-
+  auto input = in_tensors_.at(0);
   auto input_ptr = reinterpret_cast<float16_t *>(input->data_c());
-  MS_ASSERT(input_ptr != nullptr);
+  CHECK_NULL_RETURN(input_ptr);
+  auto output = out_tensors_.at(0);
   auto output_ptr = reinterpret_cast<float16_t *>(output->data_c());
-  MS_ASSERT(output_ptr != nullptr);
+  CHECK_NULL_RETURN(output_ptr);
+
+  auto hidden_state = in_tensors_.at(4);
+  CHECK_NULL_RETURN(hidden_state->data_c());
+  auto cell_state = in_tensors_.at(5);
+  CHECK_NULL_RETURN(cell_state->data_c());
+
   auto output_hidden_state = out_tensors_[1];
-  MS_ASSERT(output_hidden_state->data_c() != nullptr);
+  CHECK_NULL_RETURN(output_hidden_state->data_c());
   memcpy(output_hidden_state->data_c(), hidden_state->data_c(), hidden_state->ElementsNum() * sizeof(float16_t));
   auto output_cell_state = out_tensors_[2];
-  MS_ASSERT(output_cell_state->data_c());
+  CHECK_NULL_RETURN(output_cell_state->data_c());
   memcpy(output_cell_state->data_c(), cell_state->data_c(), cell_state->ElementsNum() * sizeof(float16_t));

   auto ret = MallocRunBuffer();

@@ -316,10 +318,10 @@ int LstmFp16CPUKernel::Run() {
     FreeRunBuffer();
     return RET_ERROR;
   }
-  MS_ASSERT(weight_i_ptr_);
-  MS_ASSERT(weight_h_ptr_);
-  MS_ASSERT(input_bias_);
-  MS_ASSERT(state_bias_);
+  CHECK_NULL_RETURN(weight_i_ptr_);
+  CHECK_NULL_RETURN(weight_h_ptr_);
+  CHECK_NULL_RETURN(input_bias_);
+  CHECK_NULL_RETURN(state_bias_);
   LstmFp16(output_ptr, input_ptr, weight_i_ptr_, weight_h_ptr_, input_bias_, state_bias_,
            reinterpret_cast<float16_t *>(output_hidden_state->data_c()),
           reinterpret_cast<float16_t *>(output_cell_state->data_c()), buffer_, lstm_param_);
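For readers tracing the raw indices above, a hedged summary of the tensor layout this kernel assumes, inferred from the accesses in this diff rather than from any spec (gate_num is presumably 4, one weight slice per LSTM gate):

/* in_tensors_:  0 input [seq_len, batch, input_size], 1 weight_i,
 *               2 weight_h, 3 bias, 4 hidden_state, 5 cell_state
 * out_tensors_: 0 output, 1 output_hidden_state, 2 output_cell_state
 * hidden_size_ = weight_i.shape[1] / gate_num                       */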
@@ -41,7 +41,7 @@ int ActivationCPUKernel::Init() {
       type_ != schema::ActivationType_TANH && type_ != schema::ActivationType_HSWISH &&
       type_ != schema::ActivationType_SWISH && type_ != schema::ActivationType_HSIGMOID &&
       type_ != schema::ActivationType_HARD_TANH && type_ != schema::ActivationType_GELU &&
-      type_ != schema::ActivationType_SOFTPLUS) {
+      type_ != schema::ActivationType_SOFTPLUS && type_ != schema::ActivationType_ELU) {
     MS_LOG(ERROR) << "Activation fp32 not support type: " << type_;
     return RET_ERROR;
   }

@@ -90,6 +90,8 @@ int ActivationCPUKernel::DoActivation(int task_id) {
     ret = Gelu(input_addr + stride * task_id, count, output_addr + stride * task_id, true);
   } else if (type_ == schema::ActivationType_SOFTPLUS) {
     ret = Softplus(input_addr + stride * task_id, count, output_addr + stride * task_id);
+  } else if (type_ == schema::ActivationType_ELU) {
+    ret = Elu(input_addr + stride * task_id, count, output_addr + stride * task_id, alpha_);
   } else {
     MS_LOG(ERROR) << "Activation type error";
     return RET_ERROR;
@@ -32,6 +32,7 @@ int CumsumLaunch(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
     return RET_NULL_PTR;
   }
   auto kernel = reinterpret_cast<CumSumCPUKernel *>(cdata);
+  CHECK_NULL_RETURN(kernel);
   auto input_tensor = kernel->in_tensors().at(0);
   int ret;
   if (input_tensor->data_type() == kNumberTypeFloat32) {

@@ -47,6 +48,9 @@ int CumsumLaunch(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
 }  // namespace

 int CumSumCPUKernel::Init() {
+  CHECK_LESS_RETURN(in_tensors_.size(), 2);
+  CHECK_LESS_RETURN(out_tensors_.size(), 1);
+  CHECK_NULL_RETURN(param_);
   if (!InferShapeDone()) {
     return RET_OK;
   }

@@ -54,20 +58,19 @@ int CumSumCPUKernel::Init() {
 }

 int CumSumCPUKernel::ReSize() {
-  MS_ASSERT(in_tensors_.size() == kInputSize1);
+  auto input_tensor = in_tensors_.at(0);
+  CHECK_NULL_RETURN(input_tensor);
   auto axis_tensor = in_tensors_.at(1);
+  CHECK_NULL_RETURN(axis_tensor);
   int *axis_data = reinterpret_cast<int *>(axis_tensor->data_c());
-  if (axis_data == nullptr) {
-    MS_LOG(ERROR) << "Cumsum axis nullptr";
-    return RET_ERROR;
-  }
+  CHECK_NULL_RETURN(axis_data);

   param_->axis_ = *axis_data;
   if (param_->axis_ < 0) {
-    param_->axis_ += in_tensors_.at(0)->shape().size();
+    param_->axis_ += input_tensor->shape().size();
   }
-  if (static_cast<int>(in_tensors_.at(0)->shape().size()) <= param_->axis_) {
-    MS_LOG(ERROR) << "axis " << param_->axis_ << " larger than in tensor rank " << in_tensors_.at(0)->shape().size();
+  if (param_->axis_ < 0 || param_->axis_ >= static_cast<int>(input_tensor->shape().size())) {
+    MS_LOG(ERROR) << "axis " << param_->axis_ << " error.";
     return RET_ERROR;
   }
   out_dim_ = 1;

@@ -79,25 +82,17 @@ int CumSumCPUKernel::ReSize() {
   for (int i = param_->axis_ + 1; i < static_cast<int>(input_tensor->shape().size()); ++i) {
     in_dim_ *= input_tensor->shape().at(i);
   }
+  MS_CHECK_FALSE(op_parameter_->thread_num_ == 0, RET_ERROR);
   unit_ = UP_DIV(out_dim_, op_parameter_->thread_num_);
   return RET_OK;
 }

 int CumSumCPUKernel::DoCumsum(int task_id) {
-  auto input_tensor = in_tensors_.at(0);
-  MS_ASSERT(input_tensor != nullptr);
-  float *input_data = reinterpret_cast<float *>(input_tensor->data_c());
-  if (input_data == nullptr) {
-    MS_LOG(ERROR) << "input data nullptr";
-    return RET_ERROR;
-  }
-  auto output_tensor = out_tensors_.at(0);
-  MS_ASSERT(output_tensor != nullptr);
-  float *output_data = reinterpret_cast<float *>(output_tensor->data_c());
-  if (output_data == nullptr) {
-    MS_LOG(ERROR) << "output data nullptr";
-    return RET_ERROR;
-  }
+  float *input_data = reinterpret_cast<float *>(in_tensors_.at(0)->data_c());
+  CHECK_NULL_RETURN(input_data);
+  float *output_data = reinterpret_cast<float *>(out_tensors_.at(0)->data_c());
+  CHECK_NULL_RETURN(output_data);

   float *input = input_data + task_id * unit_ * axis_dim_ * in_dim_;
   int out_dim = MSMIN(out_dim_ - unit_ * task_id, unit_);
   float *output = output_data + task_id * unit_ * axis_dim_ * in_dim_;

@@ -110,20 +105,11 @@ int CumSumCPUKernel::DoCumsum(int task_id) {
 }

 int CumSumCPUKernel::DoCumsumInt(int task_id) {
-  auto input_tensor = in_tensors_.at(0);
-  MS_ASSERT(input_tensor != nullptr);
-  int *input_data = reinterpret_cast<int *>(input_tensor->data_c());
-  if (input_data == nullptr) {
-    MS_LOG(ERROR) << "input data nullptr";
-    return RET_ERROR;
-  }
-  auto output_tensor = out_tensors_.at(0);
-  MS_ASSERT(output_tensor != nullptr);
-  int *output_data = reinterpret_cast<int *>(output_tensor->data_c());
-  if (output_data == nullptr) {
-    MS_LOG(ERROR) << "output data nullptr";
-    return RET_ERROR;
-  }
+  int *input_data = reinterpret_cast<int *>(in_tensors_.at(0)->data_c());
+  CHECK_NULL_RETURN(input_data);
+  int *output_data = reinterpret_cast<int *>(out_tensors_.at(0)->data_c());
+  CHECK_NULL_RETURN(output_data);

   int *input = input_data + task_id * unit_ * axis_dim_ * in_dim_;
   int out_dim = MSMIN(out_dim_ - unit_ * task_id, unit_);
   int *output = output_data + task_id * unit_ * axis_dim_ * in_dim_;
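The cumsum kernel views the tensor as [out_dim_][axis_dim_][in_dim_] around the scan axis, splits out_dim_ slabs across tasks (unit_ slabs each), and scans along the middle dimension. A hedged single-thread reference of that decomposition, ignoring any exclusive/reverse options the parameter may carry:

/* Forward inclusive cumsum over the middle (axis) dimension of a tensor
 * viewed as [out_dim][axis_dim][in_dim]. Hypothetical reference helper. */
static void cumsum_ref(const float *in, float *out, int out_dim, int axis_dim, int in_dim) {
  for (int o = 0; o < out_dim; ++o) {
    for (int i = 0; i < in_dim; ++i) {
      float sum = 0.0f;
      for (int a = 0; a < axis_dim; ++a) {
        int idx = (o * axis_dim + a) * in_dim + i;
        sum += in[idx];
        out[idx] = sum;
      }
    }
  }
}

Because slabs are independent along out_dim, the task_id * unit_ * axis_dim_ * in_dim_ offsets above partition the work with no cross-task dependencies.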
@@ -1,72 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "src/runtime/kernel/arm/fp32/elu_fp32.h"
-#include "include/errorcode.h"
-#include "src/kernel_registry.h"
-
-using mindspore::lite::KernelRegistrar;
-using mindspore::lite::RET_ERROR;
-using mindspore::lite::RET_OK;
-using mindspore::schema::PrimitiveType_Elu;
-
-namespace mindspore::kernel {
-int EluCPUKernel::Init() {
-  CHECK_LESS_RETURN(in_tensors_.size(), 1);
-  CHECK_LESS_RETURN(out_tensors_.size(), 1);
-  if (!InferShapeDone()) {
-    return RET_OK;
-  }
-  return ReSize();
-}
-
-int EluCPUKernel::ReSize() {
-  elu_parameter_->in_size_ = in_tensors_.front()->ElementsNum();
-  return RET_OK;
-}
-
-int EluCPUKernel::DoExcute(int task_id) {
-  auto input_addr = reinterpret_cast<float *>(in_tensors_.front()->MutableData());
-  auto output_addr = reinterpret_cast<float *>(out_tensors_.front()->MutableData());
-  auto error_code = Elu(input_addr, output_addr, elu_parameter_, task_id);
-  if (error_code != RET_OK) {
-    MS_LOG(ERROR) << "EluCPUKernel DoExcute error error_code[" << error_code << "]";
-    return RET_ERROR;
-  }
-  return RET_OK;
-}
-
-int EluRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
-  auto EluData = reinterpret_cast<EluCPUKernel *>(cdata);
-  auto ret = EluData->DoExcute(task_id);
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "EluRun error task_id[" << task_id << "] error_code[" << ret << "]";
-    return RET_ERROR;
-  }
-  return RET_OK;
-}
-
-int EluCPUKernel::Run() {
-  auto ret = ParallelLaunch(this->ms_context_, EluRun, this, op_parameter_->thread_num_);
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "Elu error: error_code[" << ret << "]";
-    return RET_ERROR;
-  }
-  return RET_OK;
-}
-
-REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Elu, LiteKernelCreator<EluCPUKernel>)
-}  // namespace mindspore::kernel
@@ -1,44 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ELU_H_
-#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ELU_H_
-
-#include <vector>
-#include "src/inner_kernel.h"
-#include "nnacl/fp32/elu_fp32.h"
-
-namespace mindspore::kernel {
-class EluCPUKernel : public InnerKernel {
- public:
-  EluCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
-               const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
-      : InnerKernel(parameter, inputs, outputs, ctx) {
-    elu_parameter_ = reinterpret_cast<EluParameter *>(op_parameter_);
-  }
-  ~EluCPUKernel() = default;
-
-  int Init() override;
-  int ReSize() override;
-  int Run() override;
-  int DoExcute(int task_id);
-
- private:
-  EluParameter *elu_parameter_ = nullptr;
-};
-}  // namespace mindspore::kernel
-
-#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ELU_H
@@ -27,6 +27,7 @@ namespace mindspore::kernel {
 int EmbeddingLookupCPUKernel::Init() {
   CHECK_LESS_RETURN(in_tensors_.size(), 1);
   CHECK_LESS_RETURN(out_tensors_.size(), 1);
+  CHECK_NULL_RETURN(param_);
   if (!InferShapeDone()) {
     return RET_OK;
   }

@@ -43,14 +44,17 @@ int EmbeddingLookupCPUKernel::ReSize() {

   param_->layer_num_ = 0;
   for (size_t i = 0; i < in_tensors_.size() - 1; ++i) {
+    CHECK_NULL_RETURN(in_tensors_.at(i));
     param_->layer_num_ += in_tensors_[i]->shape()[0];
   }
   return RET_OK;
 }

 int EmbeddingLookupCPUKernel::DoExcute(int task_id) {
-  auto ids_addr = reinterpret_cast<int *>(in_tensors_.back()->MutableData());
-  auto output_addr = reinterpret_cast<float *>(out_tensors_.front()->MutableData());
+  auto ids_addr = reinterpret_cast<int *>(in_tensors_.back()->data_c());
+  CHECK_NULL_RETURN(ids_addr);
+  auto output_addr = reinterpret_cast<float *>(out_tensors_.front()->data_c());
+  CHECK_NULL_RETURN(output_addr);
   int error_code = EmbeddingLookup(input_addr_, ids_addr, output_addr, param_, task_id);
   if (error_code != RET_OK) {
     MS_LOG(ERROR) << "embedding lookup error error_code[" << error_code << "]";

@@ -61,6 +65,7 @@ int EmbeddingLookupCPUKernel::DoExcute(int task_id) {

 int EmbeddingLookupRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
   auto kernel = reinterpret_cast<EmbeddingLookupCPUKernel *>(cdata);
+  CHECK_NULL_RETURN(kernel);
   auto ret = kernel->DoExcute(task_id);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "EmbeddingLookupRun error task_id[" << task_id << "] error_code[" << ret << "]";

@@ -70,7 +75,6 @@ int EmbeddingLookupRun(void *cdata, int task_id, float lhs_scale, float rhs_scal
 }

 int EmbeddingLookupCPUKernel::Run() {
-  MS_ASSERT(ms_context_->allocator != nullptr);
   input_addr_ =
     reinterpret_cast<float *>(ms_context_->allocator->Malloc(sizeof(float) * param_->layer_size_ * param_->layer_num_));
   param_->is_regulated_ = reinterpret_cast<bool *>(ms_context_->allocator->Malloc(sizeof(bool) * param_->layer_num_));

@@ -84,7 +88,12 @@ int EmbeddingLookupCPUKernel::Run() {
   }
   int dest_loc = 0;
   for (size_t i = 0; i < in_tensors_.size() - 1; i++) {
-    auto input_t = reinterpret_cast<float *>(in_tensors_.at(i)->MutableData());
+    auto input_t = reinterpret_cast<float *>(in_tensors_.at(i)->data_c());
+    if (input_t == nullptr) {
+      MS_LOG(ERROR) << "Get input tensor data failed.";
+      FreeRunBuff();
+      return RET_ERROR;
+    }
     memcpy(input_addr_ + dest_loc, input_t, sizeof(float) * in_tensors_.at(i)->ElementsNum());
     dest_loc += in_tensors_.at(i)->ElementsNum();
   }

@@ -98,8 +107,8 @@ int EmbeddingLookupCPUKernel::Run() {

 void EmbeddingLookupCPUKernel::FreeRunBuff() {
   ms_context_->allocator->Free(input_addr_);
-  ms_context_->allocator->Free(param_->is_regulated_);
+  input_addr_ = nullptr;
+  ms_context_->allocator->Free(param_->is_regulated_);
   param_->is_regulated_ = nullptr;
 }
@@ -27,8 +27,9 @@ using mindspore::schema::PrimitiveType_LayerNormFusion;

 namespace mindspore::kernel {
 int LayerNormCPUKernel::Init() {
-  CHECK_LESS_RETURN(in_tensors_.size(), DIMENSION_3D);
+  CHECK_LESS_RETURN(in_tensors_.size(), 3);
   CHECK_LESS_RETURN(out_tensors_.size(), 1);
+  CHECK_NULL_RETURN(param_);
   if (!InferShapeDone()) {
     return RET_OK;
   }

@@ -36,7 +37,9 @@ int LayerNormCPUKernel::Init() {
 }

 int LayerNormCPUKernel::ReSize() {
-  auto shape = in_tensors_.front()->shape();
+  auto input = in_tensors_.front();
+  CHECK_NULL_RETURN(input);
+  auto shape = input->shape();
   param_->begin_norm_axis_ =
     param_->begin_norm_axis_ > 0 ? param_->begin_norm_axis_ : param_->begin_norm_axis_ + shape.size();
   param_->begin_params_axis_ =

@@ -63,7 +66,7 @@ int LayerNormCPUKernel::ReSize() {
 }

 int LayerNormCPUKernel::DoLayerNorm(int thread_id) {
-  int ret = LayerNorm(src_data_, gamma_data_, beta_data_, dst_data_, mean_data_, var_data_, param_, thread_id);
+  auto ret = LayerNorm(src_data_, gamma_data_, beta_data_, dst_data_, mean_data_, var_data_, param_, thread_id);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "DoLayerNorm error error_code[" << ret << "]";
     return ret;

@@ -73,6 +76,7 @@ int LayerNormCPUKernel::DoLayerNorm(int thread_id) {

 int LayerNormRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
   auto kernel = reinterpret_cast<LayerNormCPUKernel *>(cdata);
+  CHECK_NULL_RETURN(kernel);
   auto ret = kernel->DoLayerNorm(task_id);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "LayerNormRun error task_id[" << task_id << "] error_code[" << ret << "]";

@@ -82,23 +86,25 @@ int LayerNormRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
 }

 int LayerNormCPUKernel::Run() {
-  int ret = RET_OK;
   src_data_ = reinterpret_cast<float *>(in_tensors_.at(0)->data_c());
+  CHECK_NULL_RETURN(src_data_);
   gamma_data_ = reinterpret_cast<float *>(in_tensors_.at(1)->data_c());
+  CHECK_NULL_RETURN(gamma_data_);
   beta_data_ = reinterpret_cast<float *>(in_tensors_.at(2)->data_c());
+  CHECK_NULL_RETURN(beta_data_);
   dst_data_ = reinterpret_cast<float *>(out_tensors_.at(0)->data_c());
+  CHECK_NULL_RETURN(dst_data_);

   if (out_tensors_.size() == 3) {
     mean_data_ = reinterpret_cast<float *>(out_tensors_.at(1)->data_c());
+    CHECK_NULL_RETURN(mean_data_);
     var_data_ = reinterpret_cast<float *>(out_tensors_.at(2)->data_c());
-  } else {
-    mean_data_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(param_->norm_outer_size_ * sizeof(float)));
-    var_data_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(param_->norm_outer_size_ * sizeof(float)));
-  }
-  ret = ParallelLaunch(this->ms_context_, LayerNormRun, this, op_parameter_->thread_num_);
-  if (out_tensors_.size() != 3) {
-    ms_context_->allocator->Free(mean_data_);
-    ms_context_->allocator->Free(var_data_);
+    CHECK_NULL_RETURN(var_data_);
+  } else if (out_tensors_.size() != 1) {
+    MS_LOG(ERROR) << "LayerNorm should have 1 or 3 output tensors";
+    return RET_ERROR;
   }
+  auto ret = ParallelLaunch(this->ms_context_, LayerNormRun, this, op_parameter_->thread_num_);
   return ret;
 }
@@ -29,6 +29,13 @@ using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_LogSoftmax;

 namespace mindspore::kernel {
+LogSoftmaxCPUKernel::~LogSoftmaxCPUKernel() {
+  if (tmp_data_ != nullptr) {
+    free(tmp_data_);
+    tmp_data_ = nullptr;
+  }
+}
+
 int LogSoftmaxCPUKernel::Init() {
   auto ret = SoftmaxBaseCPUKernel::Init();
   if (ret != RET_OK) {

@@ -73,19 +80,23 @@ int LogSoftmaxCPUKernel::ReSize() {
 }

 int LogSoftmaxCPUKernel::DoLogSoftmaxLastAxis(int task_id) {
+  MS_CHECK_FALSE(op_parameter_->thread_num_ == 0, RET_ERROR);
   int unit = UP_DIV(out_plane_size_, op_parameter_->thread_num_);
   int begin = task_id * unit;
   int end = MSMIN(begin + unit, out_plane_size_);
   int channel = softmax_param_->input_shape_[softmax_param_->axis_];
   int offset = begin * channel;
-  auto input_ptr = reinterpret_cast<float *>(in_tensors_.at(kInputIndex)->MutableData());
-  auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->MutableData());
+  auto input_ptr = reinterpret_cast<float *>(in_tensors_.at(kInputIndex)->data_c());
+  CHECK_NULL_RETURN(input_ptr);
+  auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->data_c());
+  CHECK_NULL_RETURN(output_ptr);
   LogSoftmaxLastAxis(input_ptr + offset, output_ptr + offset, tmp_data_ + offset, end - begin, channel);
   return RET_OK;
 }

 int LogSoftmaxLastAxisRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
   auto kernel = reinterpret_cast<LogSoftmaxCPUKernel *>(cdata);
+  CHECK_NULL_RETURN(kernel);
   auto ret = kernel->DoLogSoftmaxLastAxis(task_id);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "DoLogSoftmaxLastAxis error task_id: " << task_id << ", ret: " << ret;

@@ -102,11 +113,10 @@ int LogSoftmaxCPUKernel::Run() {
     }
   } else {
     auto input_ptr = reinterpret_cast<float *>(in_tensors_.at(kInputIndex)->data_c());
-    MS_ASSERT(input_ptr);
+    CHECK_NULL_RETURN(input_ptr);
     auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->data_c());
-    MS_ASSERT(output_ptr);
-    MS_ASSERT(tmp_data_);
-    MS_ASSERT(softmax_param_);
+    CHECK_NULL_RETURN(output_ptr);
+    CHECK_NULL_RETURN(tmp_data_);
     LogSoftmax(input_ptr, output_ptr, tmp_data_, softmax_param_);
   }
   return ret;
@@ -27,11 +27,7 @@ class LogSoftmaxCPUKernel : public SoftmaxBaseCPUKernel {
   LogSoftmaxCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                       const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
       : SoftmaxBaseCPUKernel(parameter, inputs, outputs, ctx), tmp_data_(nullptr) {}
-  ~LogSoftmaxCPUKernel() override {
-    if (tmp_data_ != nullptr) {
-      free(tmp_data_);
-    }
-  };
+  ~LogSoftmaxCPUKernel() override;

   int Init() override;
   int ReSize() override;
@@ -84,7 +84,6 @@ int LstmCPUKernel::InitInputWeightBias() {
  // weight -- row: hidden_size; col: input_size, need transpose
  // result -- row: seq_len * batch; col: hidden_size
  auto weight_i = in_tensors_.at(weight_i_index);
-  MS_ASSERT(weight_i != nullptr);
  weight_i_ptr_ = reinterpret_cast<float *>(
    malloc(weight_batch_ * lstm_param_->input_col_align_ * lstm_param_->input_size_ * sizeof(float)));
  if (weight_i_ptr_ == nullptr) {
@@ -92,6 +91,7 @@ int LstmCPUKernel::InitInputWeightBias() {
    return RET_ERROR;
  }
  auto weight_i_data = reinterpret_cast<float *>(weight_i->data_c());
+  CHECK_NULL_RETURN(weight_i_data);
  PackLstmWeight(weight_i_ptr_, weight_i_data, weight_batch_, lstm_param_->input_size_, lstm_param_->hidden_size_,
                 lstm_param_->input_col_align_);
@@ -102,8 +102,10 @@ int LstmCPUKernel::InitInputWeightBias() {
    return RET_ERROR;
  }
  memset(input_bias_, 0, weight_batch_ * lstm_param_->input_col_align_ * sizeof(float));
-  PackLstmBias(input_bias_, reinterpret_cast<float *>(in_tensors_.at(bias_index)->data_c()), weight_batch_,
-               lstm_param_->hidden_size_, lstm_param_->input_col_align_, lstm_param_->bidirectional_);
+  auto bias_data = reinterpret_cast<float *>(in_tensors_.at(bias_index)->data_c());
+  CHECK_NULL_RETURN(bias_data);
+  PackLstmBias(input_bias_, bias_data, weight_batch_, lstm_param_->hidden_size_, lstm_param_->input_col_align_,
+               lstm_param_->bidirectional_);
  return RET_OK;
}
@@ -113,8 +115,8 @@ int LstmCPUKernel::InitStateWeightBias() {
  // weight -- row: hidden_size; col: hidden_size, need transpose
  // result -- row: batch; col: hidden_size
  auto weight_h = in_tensors_.at(weight_h_index);
-  MS_ASSERT(weight_h != nullptr);
+  auto weight_h_data = reinterpret_cast<float *>(weight_h->data_c());
+  CHECK_NULL_RETURN(weight_h_data);
  if (!state_is_vec_) {
    weight_h_ptr_ = reinterpret_cast<float *>(
      malloc(weight_batch_ * lstm_param_->state_col_align_ * lstm_param_->hidden_size_ * sizeof(float)));
@@ -151,6 +153,7 @@ int LstmCPUKernel::InitStateWeightBias() {
  memset(state_bias_, 0, weight_batch_ * lstm_param_->state_col_align_ * sizeof(float));
  auto state_bias =
    reinterpret_cast<float *>(in_tensors_.at(bias_index)->data_c()) + gate_num * lstm_param_->hidden_size_;
+  CHECK_NULL_RETURN(state_bias);
  PackLstmBias(state_bias_, state_bias, weight_batch_, lstm_param_->hidden_size_, lstm_param_->state_col_align_,
               lstm_param_->bidirectional_);
  return RET_OK;
@@ -158,14 +161,12 @@ int LstmCPUKernel::InitStateWeightBias() {

int LstmCPUKernel::InitParam() {
  auto input = in_tensors_.front();
-  MS_ASSERT(input != nullptr);
  std::vector<int> in_shape = input->shape();
  lstm_param_->seq_len_ = in_shape.at(0);
  lstm_param_->batch_ = in_shape.at(1);
  lstm_param_->input_size_ = in_shape.at(2);

  auto weight_i = in_tensors_.at(weight_i_index);
-  MS_ASSERT(weight_i != nullptr);
  std::vector<int> w_shape = weight_i->shape();
  lstm_param_->hidden_size_ = w_shape.at(1) / gate_num;

@@ -190,6 +191,7 @@ int LstmCPUKernel::InitParam() {
  lstm_param_->input_row_align_ = UP_ROUND(lstm_param_->seq_len_ * lstm_param_->batch_, row_tile_);
  lstm_param_->input_col_align_ = UP_ROUND(lstm_param_->hidden_size_, col_tile_);
  input_thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(lstm_param_->input_col_align_, col_tile_));
+  MS_CHECK_FALSE(input_thread_count_ == 0, RET_ERROR);
  input_thread_stride_ = UP_DIV(UP_DIV(lstm_param_->input_col_align_, col_tile_), input_thread_count_);

  state_row_tile_ = row_tile_;
@@ -212,8 +214,15 @@ int LstmCPUKernel::InitParam() {
}

int LstmCPUKernel::Init() {
-  CHECK_LESS_RETURN(in_tensors_.size(), DIMENSION_6D);
-  CHECK_LESS_RETURN(out_tensors_.size(), 1);
+  CHECK_LESS_RETURN(in_tensors_.size(), 6);
+  for (size_t i = 0; i < in_tensors_.size(); i++) {
+    CHECK_NULL_RETURN(in_tensors_.at(i));
+  }
+  CHECK_LESS_RETURN(out_tensors_.size(), 3);
+  for (size_t i = 0; i < out_tensors_.size(); i++) {
+    CHECK_NULL_RETURN(out_tensors_.at(i));
+  }
+  CHECK_NULL_RETURN(lstm_param_);
  if (!InferShapeDone()) {
    return RET_OK;
  }
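The Init() rewrite above leans on two validation macros used throughout this commit. A simplified sketch of what they are assumed to expand to (the real definitions, which also log, live in nnacl/op_base.h and are not part of this diff; the RET_NULL_PTR code is likewise an assumption):

// Sketch only -- the actual macros also emit MS_LOG(ERROR) with context.
#define CHECK_LESS_RETURN(size, min_size) \
  do {                                    \
    if ((size) < (min_size)) {            \
      return RET_ERROR;                   \
    }                                     \
  } while (0)

#define CHECK_NULL_RETURN(ptr) \
  do {                         \
    if ((ptr) == nullptr) {    \
      return RET_NULL_PTR;     \
    }                          \
  } while (0)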
@@ -245,7 +254,7 @@ int LstmCPUKernel::ReSize() {
}

int LstmCPUKernel::MallocRunBuffer() {
-  for (int i = 0; i < 6; i++) {
+  for (int i = 0; i < 7; i++) {
    buffer_[i] = nullptr;
  }
  buffer_[packed_input_index] = reinterpret_cast<float *>(
@@ -313,13 +322,13 @@ int LstmCPUKernel::MallocRunBuffer() {
  return RET_OK;
}

-int LstmCPUKernel::InputWeightMatMul(int task_id) {
+void LstmCPUKernel::InputWeightMatMul(int task_id) {
  int current_start_oc = task_id * input_thread_stride_ * col_tile_;
  int current_rest_oc = 0;
  current_rest_oc = lstm_param_->hidden_size_ - current_start_oc;
  int cur_oc = MSMIN(input_thread_stride_ * col_tile_, current_rest_oc);
  if (cur_oc <= 0) {
-    return RET_OK;
+    return;
  }

  auto input = buffer_[packed_input_index];
@@ -328,16 +337,12 @@ int LstmCPUKernel::InputWeightMatMul(int task_id) {
  auto bias = (bias_loop_ == nullptr) ? nullptr : bias_loop_ + current_start_oc;
  MatMulOpt(input, b, c, bias, ActType_No, lstm_param_->input_size_, lstm_param_->seq_len_ * lstm_param_->batch_,
            cur_oc, lstm_param_->hidden_size_, OutType_Nhwc);
-  return RET_OK;
}

int LstmInputMulWeightRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
  auto kernel = reinterpret_cast<LstmCPUKernel *>(cdata);
-  auto ret = kernel->InputWeightMatMul(task_id);
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "InputWeightMatMul error task_id[" << task_id << "] error_code[" << ret << "]";
-    return RET_ERROR;
-  }
+  CHECK_NULL_RETURN(kernel);
+  kernel->InputWeightMatMul(task_id);
  return RET_OK;
}
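InputWeightMatMul can return void after this change because an empty slice (cur_oc <= 0) is a normal outcome for surplus threads, not an error; only the LstmInputMulWeightRun wrapper keeps an int return to satisfy the runner signature visible above. A hedged sketch of the assumed contract -- a serial stand-in, not the real ParallelLaunch:

// The runner is assumed to call func(cdata, task_id, ...) once per task_id in
// [0, task_count) and to report RET_OK only if every invocation does.
using KernelWorkFunc = int (*)(void *cdata, int task_id, float lhs_scale, float rhs_scale);

int RunTasksSerially(KernelWorkFunc func, void *cdata, int task_count) {
  for (int task_id = 0; task_id < task_count; ++task_id) {
    if (func(cdata, task_id, 0.0f, 0.0f) != 0) {  // 0 == RET_OK
      return 1;  // RET_ERROR
    }
  }
  return 0;  // RET_OK
}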
@@ -349,7 +354,10 @@ int LstmCPUKernel::LstmUnidirectional(float *output, const float *weight_i, cons
    weight_loop_ = weight_i + lstm_param_->input_size_ * lstm_param_->input_col_align_ * i;
    bias_loop_ = input_bias + lstm_param_->input_col_align_ * i;
    gate_loop_ = gate + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_ * i;
-    ParallelLaunch(this->ms_context_, LstmInputMulWeightRun, this, input_thread_count_);
+    auto ret = ParallelLaunch(this->ms_context_, LstmInputMulWeightRun, this, input_thread_count_);
+    if (ret != RET_OK) {
+      return RET_ERROR;
+    }
  }

  float *input_gate = gate;
@@ -382,7 +390,12 @@ int LstmCPUKernel::InnerExecute(float *output, const float *input, float *hidden
  // buffer_[packed_input_index] : store packed input
  PackLstmInput(input, buffer_[packed_input_index], lstm_param_->seq_len_ * lstm_param_->batch_,
                lstm_param_->input_size_);
-  LstmUnidirectional(output, weight_i_ptr_, weight_h_ptr_, input_bias_, state_bias_, hidden_state, cell_state, false);
+  auto ret =
+    LstmUnidirectional(output, weight_i_ptr_, weight_h_ptr_, input_bias_, state_bias_, hidden_state, cell_state, false);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Lstm unidirectional calculation error.";
+    return RET_ERROR;
+  }

  // backward
  if (lstm_param_->bidirectional_) {
@@ -396,45 +409,51 @@ int LstmCPUKernel::InnerExecute(float *output, const float *input, float *hidden
    float *backward_cell_state = cell_state + lstm_param_->batch_ * lstm_param_->hidden_size_;
    float *backward_hidden_state = hidden_state + lstm_param_->batch_ * lstm_param_->hidden_size_;

-    LstmUnidirectional(backward_output, backward_weight_i, backward_weight_h, backward_input_bias, backward_state_bias,
-                       backward_hidden_state, backward_cell_state, true);
+    ret = LstmUnidirectional(backward_output, backward_weight_i, backward_weight_h, backward_input_bias,
+                             backward_state_bias, backward_hidden_state, backward_cell_state, true);
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "Lstm bidirectional calculation error.";
+      return RET_ERROR;
+    }
  }
  return RET_OK;
}

int LstmCPUKernel::Run() {
-  auto input = in_tensors_.at(kInputIndex);
-  MS_ASSERT(input != nullptr);
-  auto hidden_state = in_tensors_.at(4);
-  MS_ASSERT(hidden_state != nullptr);
-  auto cell_state = in_tensors_.at(5);
-  MS_ASSERT(cell_state != nullptr);
+  auto input = in_tensors_.at(0);
  auto output = out_tensors_.at(0);
-  MS_ASSERT(output != nullptr);

  auto input_ptr = reinterpret_cast<float *>(input->data_c());
-  MS_ASSERT(input_ptr);
+  CHECK_NULL_RETURN(input_ptr);
  auto output_ptr = reinterpret_cast<float *>(output->data_c());
-  MS_ASSERT(output_ptr);
+  CHECK_NULL_RETURN(output_ptr);

+  auto hidden_state = in_tensors_.at(4);
+  CHECK_NULL_RETURN(hidden_state->data_c());
+  auto cell_state = in_tensors_.at(5);
+  CHECK_NULL_RETURN(cell_state->data_c());
+
  auto output_hidden_state = out_tensors_[1];
+  CHECK_NULL_RETURN(output_hidden_state->data_c());
  memcpy(output_hidden_state->data_c(), hidden_state->data_c(), hidden_state->ElementsNum() * sizeof(float));
  auto output_cell_state = out_tensors_[2];
+  CHECK_NULL_RETURN(output_cell_state->data_c());
  memcpy(output_cell_state->data_c(), cell_state->data_c(), cell_state->ElementsNum() * sizeof(float));

  auto ret = MallocRunBuffer();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "LstmCPUKernel MallocRunBuffer error.";
    FreeRunBuffer();
    return RET_ERROR;
  }

-  MS_ASSERT(weight_h_ptr_);
-  MS_ASSERT(weight_i_ptr_);
-  MS_ASSERT(input_bias_);
-  MS_ASSERT(state_bias_);
-  InnerExecute(output_ptr, input_ptr, reinterpret_cast<float *>(output_hidden_state->data_c()),
-               reinterpret_cast<float *>(output_cell_state->data_c()));
+  CHECK_NULL_RETURN(weight_h_ptr_);
+  CHECK_NULL_RETURN(weight_i_ptr_);
+  CHECK_NULL_RETURN(input_bias_);
+  CHECK_NULL_RETURN(state_bias_);
+  ret = InnerExecute(output_ptr, input_ptr, reinterpret_cast<float *>(output_hidden_state->data_c()),
+                     reinterpret_cast<float *>(output_cell_state->data_c()));
  FreeRunBuffer();
-  return RET_OK;
+  return ret;
}

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LSTM, LiteKernelCreator<LstmCPUKernel>)

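For readers tracing the indices in Run(), the layout implied by this hunk is summarized below; only the indices that actually appear above (0, 4, 5 on the input side and 0, 1, 2 on the output side) are confirmed by the diff, while the weight/bias positions are the usual convention and should be read as an assumption:

// Assumed fp32 LSTM tensor layout:
//   in_tensors_[0]   input, shape [seq_len, batch, input_size] (see InitParam above)
//   in_tensors_[1..3] weight_i / weight_h / bias (assumption -- indices not shown here)
//   in_tensors_[4]   input hidden state, in_tensors_[5] input cell state
//   out_tensors_[0]  output, out_tensors_[1] output hidden state, out_tensors_[2] output cell state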
@@ -36,7 +36,7 @@ class LstmCPUKernel : public InnerKernel {
  int ReSize() override;
  int Run() override;

-  int InputWeightMatMul(int task_id);
+  void InputWeightMatMul(int task_id);

 private:
  void FreeTmpBuffer();

@@ -30,6 +30,7 @@ namespace mindspore::kernel {
int SpliceCPUKernel::Init() {
  CHECK_LESS_RETURN(in_tensors_.size(), 1);
  CHECK_LESS_RETURN(out_tensors_.size(), 1);
+  CHECK_NULL_RETURN(parameter_);
  return RET_OK;
}

@@ -1,45 +0,0 @@
-/**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ELU_GRAD_FP32_H_
-#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ELU_GRAD_FP32_H_
-
-#include <vector>
-#include "src/inner_kernel.h"
-#include "nnacl/fp32/elu_fp32.h"
-
-namespace mindspore {
-namespace kernel {
-
-class EluGradCPUKernel : public InnerKernel {
- public:
-  EluGradCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
-                   const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
-      : InnerKernel(parameter, inputs, outputs, ctx) {}
-  ~EluGradCPUKernel() = default;
-
-  int Init() override;
-  int ReSize() override;
-  int Run() override;
-  int DoExcute(int task_id);
-
- private:
-  float alpha_ = 1.0;  // currently MS supports only alpha = 1.0
-};
-}  // namespace kernel
-}  // namespace mindspore
-
-#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ELU_GRAD_FP32_H_

@@ -39,33 +39,46 @@ LayerNormInt8CPUKernel::~LayerNormInt8CPUKernel() {

int LayerNormInt8CPUKernel::SetQuantArgs() {
  lite::Tensor *input = in_tensors_.at(0);
+  CHECK_NULL_RETURN(input);
  lite::Tensor *output = out_tensors_.at(0);
+  CHECK_NULL_RETURN(output);

  quant_param_ = reinterpret_cast<LayerNormQuantArg *>(malloc(sizeof(LayerNormQuantArg)));
  if (quant_param_ == nullptr) {
    MS_LOG(ERROR) << "Malloc LayerNormQuantArg for LayerNorm int8 op failed!";
    return RET_ERROR;
  }
+  if (input->quant_params().size() < 1) {
+    MS_LOG(ERROR) << "Get LayerNorm int8 op input tensor quant params error.";
+    return RET_ERROR;
+  }
  quant_param_->in_zp_ = input->quant_params().front().zeroPoint;
  quant_param_->in_scale_ = input->quant_params().front().scale;

+  if (output->quant_params().size() < 1) {
+    MS_LOG(ERROR) << "Get LayerNorm int8 op output tensor quant params error.";
+    return RET_ERROR;
+  }
  quant_param_->out_zp_ = output->quant_params().front().zeroPoint;
  quant_param_->out_scale_ = output->quant_params().front().scale;

  lite::Tensor *gamma_tensor = in_tensors_.at(1);
-  lite::Tensor *beta_tensor = in_tensors_.at(2);
-
+  CHECK_NULL_RETURN(gamma_tensor);
+  if (gamma_tensor->quant_params().size() < 1) {
+    MS_LOG(ERROR) << "LayerNorm int8 op gamma tensor error.";
+    return RET_ERROR;
+  }
  double gamma_scale = gamma_tensor->quant_params().front().scale;
  int gamma_zp = gamma_tensor->quant_params().front().zeroPoint;
  gamma_ptr_ = reinterpret_cast<float *>(malloc(gamma_tensor->ElementsNum() * sizeof(float)));
  if (gamma_ptr_ == nullptr) {
    MS_LOG(ERROR) << "malloc gamma_ptr_ failed";
    return RET_ERROR;
  }
-  CHECK_NULL_RETURN(gamma_ptr_);
  int8_t *src_gamma = reinterpret_cast<int8_t *>(gamma_tensor->data_c());
  for (int i = 0; i < gamma_tensor->ElementsNum(); i++) {
    gamma_ptr_[i] = (src_gamma[i] - gamma_zp) * gamma_scale;
  }

+  lite::Tensor *beta_tensor = in_tensors_.at(2);
+  CHECK_NULL_RETURN(beta_tensor);
  beta_ptr_ = reinterpret_cast<float *>(malloc(beta_tensor->ElementsNum() * sizeof(float)));
  if (beta_ptr_ == nullptr) {
    MS_LOG(ERROR) << "malloc beta_ptr_ failed";
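The gamma loop above is plain affine dequantization, real_value = scale * (quantized - zero_point). A tiny self-contained example of the same arithmetic with made-up numbers:

#include <cstdint>
#include <cstdio>

int main() {
  const int8_t src_gamma[3] = {-128, 0, 127};  // quantized weights
  const int gamma_zp = 0;                      // zero point
  const double gamma_scale = 0.05;             // scale
  for (int i = 0; i < 3; ++i) {
    const float g = (src_gamma[i] - gamma_zp) * gamma_scale;
    printf("%.2f\n", g);  // -6.40, 0.00, 6.35
  }
  return 0;
}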
@@ -81,6 +94,9 @@ int LayerNormInt8CPUKernel::SetQuantArgs() {
}

int LayerNormInt8CPUKernel::Init() {
+  CHECK_LESS_RETURN(in_tensors_.size(), 3);
+  CHECK_LESS_RETURN(out_tensors_.size(), 1);
+  CHECK_NULL_RETURN(param_);
  SetQuantArgs();

  if (!InferShapeDone()) {
@@ -127,6 +143,7 @@ int LayerNormInt8CPUKernel::DoExecute(int task_id) {

int LayerNormInt8Run(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
  auto kernel = reinterpret_cast<LayerNormInt8CPUKernel *>(cdata);
+  CHECK_NULL_RETURN(kernel);
  auto ret = kernel->DoExecute(task_id);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "LayerNormInt8Run task_id " << task_id << " failed.";
@@ -137,7 +154,9 @@ int LayerNormInt8Run(void *cdata, int task_id, float lhs_scale, float rhs_scale)

int LayerNormInt8CPUKernel::Run() {
  src_ptr_ = reinterpret_cast<int8_t *>(in_tensors_.at(0)->data_c());
+  CHECK_NULL_RETURN(src_ptr_);
  dst_ptr_ = reinterpret_cast<int8_t *>(out_tensors_.at(0)->data_c());
+  CHECK_NULL_RETURN(dst_ptr_);

  auto ret = ParallelLaunch(this->ms_context_, LayerNormInt8Run, this, op_parameter_->thread_num_);
  if (ret != RET_OK) {

@@ -30,12 +30,13 @@ namespace mindspore::kernel {
int FillOpenCLKernel::RunFill() {
  auto allocator_ = ocl_runtime_->GetAllocator();
  auto param = reinterpret_cast<FillParameter *>(this->op_parameter_);
+  CHECK_NULL_RETURN(param);
  default_ = param->num_dims_;
  ImageSize img_size;
  cl_int4 fill_value = {};
  fill_value.s[0] = fill_value.s[1] = fill_value.s[2] = fill_value.s[3] = default_;
  auto src_data = out_tensors_[0]->data_c();
-  MS_ASSERT(src_data);
+  CHECK_NULL_RETURN(src_data);
  if (allocator_->GetImageSize(src_data, &img_size) != RET_OK) {
    MS_LOG(ERROR) << "GetImageSize failed.";
    return RET_ERROR;
@@ -53,12 +54,13 @@ int FillOpenCLKernel::RunFill() {

int FillOpenCLKernel::RunShape() {
  auto allocator_ = ocl_runtime_->GetAllocator();
+  CHECK_NULL_RETURN(allocator_);
  auto src_data = out_tensors_[0]->data_c();
-  MS_ASSERT(src_data);
+  CHECK_NULL_RETURN(src_data);
  cl_int4 fill_value = {default_, default_, default_, default_};
  auto tensor_shape = in_tensors_[0]->shape();
  void *tensor_shape_data = tensor_shape.data();
-  MS_ASSERT(tensor_shape_data);
+  CHECK_NULL_RETURN(tensor_shape_data);
  for (int i = 0; i < tensor_shape.size(); ++i) {
    fill_value.s[i] = reinterpret_cast<int *>(tensor_shape_data)[i];
  }
@@ -84,12 +86,16 @@ int FillOpenCLKernel::CheckSpecs() {
  }
  auto param = this->op_parameter_;

-  if (out_tensors_[0]->shape().size() > OUTPUT_TENSOR_SIZE_4) {
-    MS_LOG(ERROR) << " only support dim <= 4";
+  auto input = in_tensors_.at(0);
+  CHECK_NULL_RETURN(input);
+  if (input->shape().size() > DIMENSION_1D && param->type_ == PrimitiveType_Fill) {
+    MS_LOG(ERROR) << " fill only support dim = 1";
    return RET_ERROR;
  }
-  if (in_tensors_[0]->shape().size() > DIMENSION_1D && param->type_ == PrimitiveType_Fill) {
-    MS_LOG(ERROR) << " fill only support dim = 1";
+  auto output = out_tensors_.at(0);
+  CHECK_NULL_RETURN(output);
+  if (output->shape().size() > OUTPUT_TENSOR_SIZE_4) {
+    MS_LOG(ERROR) << " only support dim <= 4";
    return RET_ERROR;
  }
  return RET_OK;

@@ -32,19 +32,24 @@ using mindspore::schema::PrimitiveType_LayerNormFusion;
namespace mindspore::kernel {
int LayerNormOpenCLKernel::CheckSpecs() {
  auto param = reinterpret_cast<LayerNormParameter *>(this->op_parameter_);
+  CHECK_NULL_RETURN(param);
  if (in_tensors_.size() != INPUT_TENSOR_SIZE_3 || out_tensors_.size() != OUTPUT_TENSOR_SIZE_1) {
    MS_LOG(ERROR) << "UnSupported in_tensors_.size: " << in_tensors_.size()
                  << " out_tensors_.size(): " << out_tensors_.size();
    return RET_ERROR;
  }
-  if (in_tensors_.at(0)->shape().size() != DIMENSION_4D) {
-    MS_LOG(ERROR) << "UnSupported in_tensors_.shape.size: " << in_tensors_.at(0)->shape().size();
+  auto *input = in_tensors_.at(0);
+  CHECK_NULL_RETURN(input);
+  auto *output = out_tensors_.at(0);
+  CHECK_NULL_RETURN(output);
+  if (input->shape().size() != DIMENSION_4D) {
+    MS_LOG(ERROR) << "UnSupported in_tensors_.shape.size: " << input->shape().size();
    return RET_ERROR;
  }
  normalized_axis_ = param->begin_params_axis_;
  epsilon_ = param->epsilon_;
  if (normalized_axis_ < 0) {
-    normalized_axis_ += in_tensors_.at(0)->shape().size();
+    normalized_axis_ += input->shape().size();
  }
  if (normalized_axis_ != 3) {
    MS_LOG(ERROR) << "UnSupported normalized_axis_ : " << param->normalized_dims_;
@@ -57,8 +62,10 @@ void LayerNormGetWorkGroup(const std::vector<size_t> &global, std::vector<size_t
  const int max_divider = 8;
  const int max_x = 4, max_y = 8;
  int x = std::min(GetMaxDivisorStrategy1(global[0], max_divider), max_x);
+  MS_ASSERT(x);
  int yz = max_size / x;
  int y = std::min(std::min(GetMaxDivisorStrategy1(global[1], max_divider), yz), max_y);
+  MS_ASSERT(y);
  int z = std::min(yz / y, static_cast<int>(UP_DIV(global[2], 2)));

  local->clear();
@@ -117,8 +124,10 @@ void LayerNormOpenCLKernel::SetGlobalLocal() {

int LayerNormOpenCLKernel::Initweight() {
  auto allocator = ocl_runtime_->GetAllocator();
+  CHECK_NULL_RETURN(allocator);
  GpuTensorInfo img_info(in_tensors_.at(1));
+  auto weight_tensor = in_tensors_.at(1);
+  CHECK_NULL_RETURN(weight_tensor);
  size_t weight_size = img_info.Image2DSize;
  // allocated memory for weight and init value
  gamma_ = allocator->Malloc(weight_size, lite::opencl::MemType::BUF);
@@ -141,8 +150,9 @@ int LayerNormOpenCLKernel::Initweight() {
  }
  memset(gamma_, 0x01, weight_size);
  memset(beta_, 0x00, weight_size);
-  MS_ASSERT(in_tensors_.at(1)->data_c());
-  MS_ASSERT(in_tensors_.at(INPUT_TENSOR_SIZE_2)->data_c());
+  CHECK_NULL_RETURN(in_tensors_.at(1)->data_c());
+  CHECK_NULL_RETURN(in_tensors_.at(2));
+  CHECK_NULL_RETURN(in_tensors_.at(2)->data_c());

  if (weight_tensor->data_type() == kNumberTypeFloat16) {
    if (use_fp16_enable_) {

@@ -94,6 +94,7 @@ int ReduceOpenCLKernel::SetAxes() {
    return RET_ERROR;
  }
  // copy axes from tensor to private var
+  CHECK_NULL_RETURN(axes_tensor->data_c());
  for (int i = 0; i < std::min(num_axes, MAX_SHAPE_SIZE); ++i) {
    axes_[i] = reinterpret_cast<int *>(axes_tensor->data_c())[i];
  }
@@ -128,12 +129,15 @@ int ReduceOpenCLKernel::CheckSpecs() {
    MS_LOG(ERROR) << "in size: " << in_tensors_.size() << ", out size: " << out_tensors_.size();
    return RET_ERROR;
  }
-  if (in_tensors_[0]->shape()[0] > DIMENSION_1D) {
+  auto input = in_tensors_.at(0);
+  CHECK_NULL_RETURN(input);
+  if (input->shape()[0] > DIMENSION_1D) {
    MS_LOG(ERROR) << "reduce op only support n = 1";
    return RET_PARAM_INVALID;
  }
  inShape = GpuTensorInfo(in_tensors_[0]);
  auto reduce_param = reinterpret_cast<ReduceParameter *>(op_parameter_);
+  CHECK_NULL_RETURN(reduce_param);
  if (GetReduceTypeStr(reduce_param->mode_).empty()) {
    MS_LOG(ERROR) << "not supported reduce type:" << reduce_param->mode_;
    return RET_PARAM_INVALID;

@@ -1,81 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <iostream>
-#include "src/runtime/kernel/arm/fp32/elu_fp32.h"
-#include "nnacl/fp32/elu_fp32.h"
-#include "src/common/file_utils.h"
-#include "common/common_test.h"
-#include "src/common/log_adapter.h"
-
-namespace mindspore {
-using mindspore::lite::Tensor;
-
-class TestEluFp32 : public mindspore::CommonTest {
- public:
-  TestEluFp32() {}
-};
-
-void EluTestInit(std::vector<Tensor *> *inputs_, std::vector<Tensor *> *outputs_, EluParameter *elu_param) {
-  Tensor *in_t_first = new Tensor(kNumberTypeFloat32, {6, 2}, mindspore::NHWC, lite::Tensor::Category::CONST_TENSOR);
-  in_t_first->MallocData();
-  float in_first[] = {-1, 2, -3, 4, -5, 6, -7, 8, -9, 10, -11, 0};
-  memcpy(in_t_first->MutableData(), in_first, sizeof(float) * in_t_first->ElementsNum());
-  inputs_->push_back(in_t_first);
-
-  Tensor *outputs_t = new Tensor(kNumberTypeFloat32, {6, 2}, mindspore::NHWC, lite::Tensor::Category::CONST_TENSOR);
-  outputs_t->MallocData();
-  outputs_->push_back(outputs_t);
-
-  elu_param->alpha_ = 2.0;
-}
-
-TEST_F(TestEluFp32, EluTest) {
-  std::vector<Tensor *> inputs_;
-  std::vector<Tensor *> outputs_;
-  auto elu_param_ = new EluParameter();
-  EluTestInit(&inputs_, &outputs_, elu_param_);
-
-  lite::InnerContext *ctx = new lite::InnerContext;
-  ctx->thread_num_ = 2;
-  ASSERT_EQ(lite::RET_OK, ctx->Init());
-  kernel::EluCPUKernel *elu =
-    new kernel::EluCPUKernel(reinterpret_cast<OpParameter *>(elu_param_), inputs_, outputs_, ctx);
-
-  elu->Init();
-  elu->Run();
-
-  std::cout << "output shape:" << std::endl;
-  for (unsigned int i = 0; i < outputs_.front()->shape().size(); ++i) {
-    std::cout << outputs_.front()->shape()[i] << ' ';
-  }
-  std::cout << std::endl;
-  float *out = reinterpret_cast<float *>(outputs_.front()->MutableData());
-  for (int i = 0; i < outputs_.front()->ElementsNum(); ++i) {
-    std::cout << out[i] << ' ';
-  }
-  std::cout << std::endl;
-  delete ctx;
-  for (unsigned int i = 0; i < inputs_.size(); i++) {
-    delete inputs_[i];
-  }
-  for (unsigned int i = 0; i < outputs_.size(); i++) {
-    delete outputs_[i];
-  }
-  delete elu;
-}
-
-};  // namespace mindspore

@@ -61,9 +61,25 @@ ops::PrimitiveC *CaffeTanhParser::Parse(const caffe::LayerParameter &proto, cons
  return prim.release();
}

+ops::PrimitiveC *CaffeEluParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) {
+  auto prim = std::make_unique<ops::Activation>();
+  MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
+  prim->set_activation_type(mindspore::ActivationType::ELU);
+
+  if (proto.has_elu_param()) {
+    const caffe::ELUParameter &eluParameter = proto.elu_param();
+    if (eluParameter.has_alpha()) {
+      prim->set_alpha(eluParameter.alpha());
+    }
+  }
+
+  return prim.release();
+}
+
CaffeNodeRegistrar g_caffeReluParser("ReLU", new CaffeReluParser());
CaffeNodeRegistrar g_caffeRelu6Parser("ReLU6", new CaffeRelu6Parser());
CaffeNodeRegistrar g_caffeSigmoidParser("Sigmoid", new CaffeSigmoidParser());
CaffeNodeRegistrar g_caffeTanhParser("TanH", new CaffeTanhParser());
+CaffeNodeRegistrar g_caffeEluParser("Elu", new CaffeEluParser());
}  // namespace lite
}  // namespace mindspore

@@ -54,6 +54,14 @@ class CaffeTanhParser : public CaffeNodeParser {
  ops::PrimitiveC *Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) override;
};

+class CaffeEluParser : public CaffeNodeParser {
+ public:
+  CaffeEluParser() : CaffeNodeParser("elu") {}
+  ~CaffeEluParser() override = default;
+
+  ops::PrimitiveC *Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) override;
+};
}  // namespace lite
}  // namespace mindspore

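Both converters now map ELU onto the generic activation primitive (ops::Activation with activation_type = ELU and the layer's alpha) instead of the dedicated ops::Elu deleted below. A hedged usage sketch -- the header path and surrounding scaffolding are assumptions, while the parser calls mirror the diff:

// #include "tools/converter/parser/caffe/caffe_activation_parser.h"  // assumed location
void ParseEluExample() {
  caffe::LayerParameter proto;
  proto.mutable_elu_param()->set_alpha(0.5f);  // ELUParameter.alpha via generated protobuf API

  mindspore::lite::CaffeEluParser parser;
  auto *prim = parser.Parse(proto, caffe::LayerParameter());
  // prim is an ops::Activation with activation type ELU and alpha 0.5f;
  // the caller owns the pointer released by the parser.
  delete prim;
}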
@@ -1,40 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "tools/converter/parser/caffe/caffe_elu_parser.h"
-#include <memory>
-#include "ops/elu.h"
-#include "nnacl/op_base.h"
-
-namespace mindspore {
-namespace lite {
-ops::PrimitiveC *CaffeEluParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) {
-  auto prim = std::make_unique<ops::Elu>();
-
-  MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
-  if (proto.has_elu_param()) {
-    const caffe::ELUParameter &eluParameter = proto.elu_param();
-    if (eluParameter.has_alpha()) {
-      prim->set_alpha(eluParameter.alpha());
-    }
-  }
-
-  return prim.release();
-}
-
-CaffeNodeRegistrar g_caffeEluParser("ELU", new CaffeEluParser());
-}  // namespace lite
-}  // namespace mindspore
@@ -1,36 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_ELU_PARSER_H_
-#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_ELU_PARSER_H_
-
-#include <vector>
-#include "tools/converter/parser/caffe/caffe_node_parser.h"
-#include "tools/converter/parser/caffe/caffe_node_parser_registry.h"
-
-namespace mindspore {
-namespace lite {
-class CaffeEluParser : public CaffeNodeParser {
- public:
-  CaffeEluParser() : CaffeNodeParser("elu") {}
-  ~CaffeEluParser() override = default;
-
-  ops::PrimitiveC *Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) override;
-};
-}  // namespace lite
-}  // namespace mindspore
-
-#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_ELU_PARSER_H_

@@ -104,8 +104,9 @@ ops::PrimitiveC *OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, cons
}

ops::PrimitiveC *OnnxEluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
-  auto prim = std::make_unique<ops::Elu>();
+  auto prim = std::make_unique<ops::Activation>();
  MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
+  prim->set_activation_type(mindspore::ActivationType::ELU);
  for (const auto &onnx_node_attr : onnx_node.attribute()) {
    const auto &attribute_name = onnx_node_attr.name();
    if (attribute_name == "alpha") {

@@ -18,7 +18,6 @@
#include <map>
#include <mutex>
#include <string>
-#include "nnacl/op_base.h"

namespace mindspore {
namespace registry {
@@ -37,7 +36,9 @@ converter::NodeParserPtr NodeParserRegistry::GetNodeParser(converter::FmkType fm
  if (iter_level1 == node_parser_room.end()) {
    return nullptr;
  }
-  MS_CHECK_FALSE(node_type.empty(), nullptr);
+  if (node_type.empty()) {
+    return nullptr;
+  }
  auto iter_level2 = iter_level1->second.find(node_type);
  if (iter_level2 == iter_level1->second.end()) {
    return nullptr;
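Dropping the nnacl/op_base.h include at the top of this file is what forces the MS_CHECK_FALSE call below it to become a plain if; the macro is assumed to expand roughly to:

// Sketch of the assumed expansion (the real definition lives in nnacl/op_base.h):
#define MS_CHECK_FALSE(value, errcode) \
  do {                                 \
    if ((value)) {                     \
      return (errcode);                \
    }                                  \
  } while (0)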