code review

sunsuodong 2021-12-23 20:02:36 -08:00
parent 88930656e7
commit 8d9f661785
24 changed files with 112 additions and 98 deletions


@ -169,9 +169,7 @@ mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_avx512_f
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_avx512_fp32.c:nnacl_gemm_avx512_4x64_kernel_nhwc_fp32
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_avx512_fp32.c:nnacl_gemm_avx512_5x64_kernel_nhwc_fp32
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_avx512_fp32.c:nnacl_gemm_avx512_6x64_kernel_nhwc_fp32
<<<<<<< HEAD
mindspore/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.cc:mindspore::kernel::MatmulFp32BaseCPUKernel::Run
=======
mindspore/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_partition.cc:mindspore::parallel::GetWeights
mindspore/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_partition.cc:mindspore::parallel::PartitionNode
>>>>>>> Updating the redistribution cost in D-Rec cost model
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/instance_norm_fp16.c:InstanceNormNC8HW8Fp16


@ -398,6 +398,8 @@ void ConvDw3x3Fp16(float16_t *output_data, float16_t *buffer, const float16_t *i
void ConvDwFp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data,
const float16_t *bias_data, const ConvParameter *conv_param, int task_id) {
NNACL_CHECK_ZERO_RETURN(conv_param->stride_w_);
NNACL_CHECK_ZERO_RETURN(conv_param->dilation_h_);
NNACL_CHECK_ZERO_RETURN(conv_param->thread_num_);
int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_);
int h_start = h_step * task_id;
int h_end = MSMIN(h_start + h_step, conv_param->output_h_);
@ -484,6 +486,8 @@ void DepthwiseBorderPixelFp16(float16_t *dst, const float16_t *src, const float1
void DepthwiseBorderFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, int top,
int bottom, int left, int right, const ConvParameter *conv_param,
const SlidingWindowParam *sliding) {
NNACL_CHECK_ZERO_RETURN(conv_param->dilation_h_);
NNACL_CHECK_ZERO_RETURN(conv_param->dilation_w_);
bool relu = conv_param->act_type_ == ActType_Relu;
bool relu6 = conv_param->act_type_ == ActType_Relu6;
float16_t *dst_h = dst + top * sliding->out_h_step_;
@ -644,6 +648,8 @@ void DeconvDepthwiseBorderPixelFp16(float16_t *dst, const float16_t *src, const
void DeconvDepthwiseBorderFp16(float16_t *dst, const float16_t *src, const float16_t *weight, int top, int bottom,
int left, int right, const ConvParameter *conv_param,
const SlidingWindowParam *sliding) {
NNACL_CHECK_ZERO_RETURN(conv_param->dilation_h_);
NNACL_CHECK_ZERO_RETURN(conv_param->dilation_w_);
const float16_t *src_h = src + top * sliding->out_h_step_;
for (int ih = top; ih < bottom; ih++) {
int oh = ih * conv_param->stride_h_ - conv_param->pad_u_;
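Reviewer note: the NNACL_CHECK_ZERO_RETURN guards added in this file (and the NNACL_CHECK_ZERO_RETURN_ERR variants used elsewhere in this commit) sit directly in front of divisions such as UP_DIV(conv_param->output_h_, conv_param->thread_num_) and the loops that divide by stride/dilation, so a zero parameter can no longer trigger a divide-by-zero. The macros are defined in the nnacl headers; as a rough sketch of the intended semantics (not the verbatim definitions), they amount to:

/* Sketch only: the real macros live in the nnacl headers and may differ in detail. */
#define NNACL_CHECK_ZERO_RETURN(val) \
  do {                               \
    if ((val) == 0) {                \
      return; /* void kernels silently skip the work */ \
    }                                \
  } while (0)

#define NNACL_CHECK_ZERO_RETURN_ERR(val) \
  do {                                   \
    if ((val) == 0) {                    \
      return NNACL_ERR; /* int kernels propagate an nnacl error code */ \
    }                                    \
  } while (0)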


@ -21,6 +21,7 @@
#include "nnacl/crop_parameter.h"
void Fp16Crop(const float16_t *input, float16_t *output, int task_id, const CropParameter *para) {
NNACL_CHECK_ZERO_RETURN(para->thread_count_);
int input_dim = para->input_dim_;
switch (input_dim) {
case 1:


@ -43,6 +43,9 @@ int DeConvPostFp16(const float16_t *src, float16_t *tmp, const float16_t *bias,
int dst_kh_stride = conv_param->dilation_h_ * conv_param->output_w_ * C8NUM;
int dst_kw_stride = conv_param->dilation_w_ * C8NUM;
NNACL_CHECK_ZERO_RETURN_ERR(conv_param->dilation_h_);
NNACL_CHECK_ZERO_RETURN_ERR(conv_param->dilation_w_);
for (int c = 0; c < oc8; c += 8) {
float16_t *dst_ptr = tmp + c * output_plane;
const float16_t *src_ptr = src + c * in_plane16 * kernel_plane;
@ -88,10 +91,10 @@ int DeConvPostFp16(const float16_t *src, float16_t *tmp, const float16_t *bias,
dst_kw_index[i] += src_kw_index[i];
}
#endif
} /*kw*/
} /*kh*/
} /*iw*/
} /*ih*/
} // kw
} // kh
} // iw
} // ih
/* add bias for current oh*ow*C8
* write to output data ptr in nhwc format */
@ -105,6 +108,6 @@ int DeConvPostFp16(const float16_t *src, float16_t *tmp, const float16_t *bias,
vst1q_f16(pack_tmp_data, data_v);
pack_tmp_data += C8NUM;
}
} /*oc8*/
} // oc8
return NNACL_OK;
}


@ -344,9 +344,7 @@ int PackDeConvWgDataFp16(const float16_t *nhwc_weight, DeConvComputeUnit *unit,
void DeconvWgFp16(const float16_t *nhwc_input_, float16_t *tile_in, float16_t *tile_out, int start_index,
int calculate_count, const ConvParameter *conv_param, DeConvParam *deconv_param, int task_id) {
if (deconv_param->in_tile_w_count_ == 0) {
return;
}
NNACL_CHECK_ZERO_RETURN(deconv_param->in_tile_w_count_);
/* pack tile input */
int tile_in_unit_stride = deconv_param->ic_up_ * DECONV_WINOGRAD_DEFAULT_TILE;
float16x4_t zero = vdup_n_f16(0.0f);
@ -411,6 +409,7 @@ void DeconvWgFp16(const float16_t *nhwc_input_, float16_t *tile_in, float16_t *t
void DeconvWgPostFp16(const float16_t *tile_out, float16_t *nc4hw4_output, const ConvParameter *conv_param,
const DeConvParam *deconv_param, int calculate_count, int tile_index) {
NNACL_CHECK_ZERO_RETURN(deconv_param->in_tile_w_count_);
/* merge */
int src_unit_stride = deconv_param->oc_up_ * DECONV_WINOGRAD_DEFAULT_TILE;


@ -22,8 +22,10 @@ int InstanceNormFp16(const float16_t *src_data, float16_t *dst_data, const float
const float16_t *beta_data, const InstanceNormParameter *param, size_t task_id) {
NNACL_CHECK_NULL_RETURN_ERR(src_data);
NNACL_CHECK_NULL_RETURN_ERR(dst_data);
NNACL_CHECK_ZERO_RETURN_ERR(param->op_parameter_.thread_num_);
int channel = param->channel_;
int hw_plane = param->inner_size_;
NNACL_CHECK_ZERO_RETURN_ERR(hw_plane);
int channel_step = UP_DIV(channel, param->op_parameter_.thread_num_);
int channel_begin = task_id * channel_step;
int channel_end = MSMIN(channel_begin + channel_step, channel);
@ -86,8 +88,10 @@ int InstanceNormNC8HW8Fp16(const float16_t *src_data, float16_t *dst_data, const
const float16_t *beta_data, const InstanceNormParameter *param, size_t task_id) {
NNACL_CHECK_NULL_RETURN_ERR(src_data);
NNACL_CHECK_NULL_RETURN_ERR(dst_data);
NNACL_CHECK_ZERO_RETURN_ERR(param->op_parameter_.thread_num_);
int channel = param->channel_;
int hw_plane = param->inner_size_;
NNACL_CHECK_ZERO_RETURN_ERR(hw_plane);
int channel_step = UP_DIV(UP_DIV(channel, C8NUM), param->op_parameter_.thread_num_) * C8NUM;
int channel_begin = (int)(task_id)*channel_step;
int channel_end = MSMIN(channel_begin + channel_step, channel);


@ -72,6 +72,7 @@ int LayerNormFp16(const float16_t *src_data, const float16_t *gamma_data, const
}
NNACL_CHECK_ZERO_RETURN_ERR(param->params_inner_size_);
NNACL_CHECK_ZERO_RETURN_ERR(param->params_outer_size_);
NNACL_CHECK_ZERO_RETURN_ERR(param->op_parameter_.thread_num_);
int step = UP_DIV(param->norm_outer_size_, param->op_parameter_.thread_num_);
int thread_end = MSMIN((task_id + 1) * step, param->norm_outer_size_);
for (int i = task_id * step; i < thread_end; i++) {


@ -45,6 +45,14 @@
} \
} while (0)
#define CHECK_NULL_RETURN_VOID(ptr) \
do { \
if ((ptr) == nullptr) { \
MS_LOG(ERROR) << #ptr << " must not be null!"; \
return; \
} \
} while (0)
#define CHECK_LESS_RETURN(size1, size2) \
do { \
if ((size1) < (size2)) { \
@ -55,6 +63,7 @@
#else
#define CHECK_NULL_RETURN(ptr)
#define CHECK_NULL_RETURN_VOID(ptr)
#define CHECK_LESS_RETURN(size1, size2)
#endif
#endif // MINDSPORE_LITE_SRC_COMMON_LOG_UTIL_H_
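Reviewer note: CHECK_NULL_RETURN_VOID complements the existing CHECK_NULL_RETURN for functions that return void, which is why the PackWeight() methods below switch from MS_ASSERT to it. A minimal usage sketch (the kernel class name is illustrative, mirroring the PackWeight edits later in this diff):

void SomeConvFp16CPUKernel::PackWeight() {
  auto weight_tensor = in_tensors_.at(kWeightIndex);
  void *origin_weight = (op_parameter_->is_train_session_) ? weight_tensor->data() : origin_weight_;
  CHECK_NULL_RETURN_VOID(origin_weight);  // logs an error and returns, instead of a debug-only assert
  // ... pack origin_weight into packed_weight_ ...
}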


@ -294,6 +294,7 @@ int Convolution1x1FP16CPUKernel::Run() {
}
if (RepackWeight() != RET_OK) {
MS_LOG(ERROR) << "Repack weight failed.";
ctx_->allocator->Free(pack_input_);
return RET_ERROR;
}
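Reviewer note: several Run() methods in this commit receive the same fix as above: when RepackWeight() fails after working buffers have already been allocated, those buffers are now released before the early return instead of being leaked. A schematic of the pattern (the kernel name and buffer-management helpers are illustrative; each kernel frees whatever it allocated earlier in Run()):

int SomeFp16CPUKernel::Run() {
  if (InitRunBuf() != RET_OK) {  // allocates temporary/packed buffers
    return RET_ERROR;
  }
  if (RepackWeight() != RET_OK) {
    MS_LOG(ERROR) << "Repack weight failed.";
    FreeRunBuf();  // new: release the buffers before bailing out
    return RET_ERROR;
  }
  // ... launch the compute kernel, then FreeRunBuf() on the normal path ...
  return RET_OK;
}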


@ -30,7 +30,7 @@ void ConvolutionDepthwise3x3Fp16CPUKernel::PackWeight() {
auto weight_tensor = in_tensors_.at(kWeightIndex);
int channel = weight_tensor->Batch();
void *origin_weight = (op_parameter_->is_train_session_) ? weight_tensor->data() : origin_weight_;
MS_ASSERT(origin_weight != nullptr);
CHECK_NULL_RETURN_VOID(origin_weight);
PackWeightConvDw3x3Fp16(reinterpret_cast<float16_t *>(origin_weight), reinterpret_cast<float16_t *>(packed_weight_),
channel);
}


@ -26,7 +26,7 @@ namespace mindspore::kernel {
void ConvolutionDepthwiseFp16CPUKernel::PackWeight() {
auto weight_tensor = in_tensors_.at(kWeightIndex);
void *origin_weight = (op_parameter_->is_train_session_) ? weight_tensor->data() : origin_weight_;
MS_ASSERT(origin_weight != nullptr);
CHECK_NULL_RETURN_VOID(origin_weight);
PackNCHWToNHWCFp16(reinterpret_cast<float16_t *>(origin_weight), reinterpret_cast<float16_t *>(packed_weight_), 1,
weight_tensor->Height() * weight_tensor->Width(), weight_tensor->Batch(), 0, 0);
}


@ -57,7 +57,7 @@ int ConvolutionDepthwiseSWFp16CPUKernel::InitPackedInputOutput() {
void ConvolutionDepthwiseSWFp16CPUKernel::PackWeight() {
auto weight_tensor = in_tensors_.at(kWeightIndex);
void *origin_weight = (op_parameter_->is_train_session_) ? weight_tensor->data() : origin_weight_;
MS_ASSERT(origin_weight != nullptr);
NNACL_CHECK_NULL_RETURN_VOID(origin_weight);
PackNCHWFp16ToNC8HW8Fp16(reinterpret_cast<float16_t *>(origin_weight), reinterpret_cast<float16_t *>(packed_weight_),
1, weight_tensor->Height() * weight_tensor->Width(), weight_tensor->Batch());
}
@ -171,6 +171,7 @@ int ConvolutionDepthwiseSWFp16CPUKernel::Run() {
}
if (RepackWeight() != RET_OK) {
MS_LOG(ERROR) << "Repack weight failed.";
FreePackedInputOutput();
return RET_ERROR;
}
ret = ParallelLaunch(this->ms_context_, ConvDwSWFp16Run, this, conv_param_->thread_num_);


@ -33,7 +33,7 @@ void ConvolutionFP16CPUKernel::PackWeight() {
int out_channel = filter_tensor->Batch();
int kernel_plane = filter_tensor->Height() * filter_tensor->Width();
void *weight_origin = (op_parameter_->is_train_session_) ? filter_tensor->data() : origin_weight_;
MS_ASSERT(weight_origin != nullptr);
CHECK_NULL_RETURN_VOID(weight_origin);
RowMajor2Col8MajorFp16(weight_origin, reinterpret_cast<float16_t *>(packed_weight_), out_channel,
in_channel * kernel_plane, false);
}
@ -178,6 +178,7 @@ int ConvolutionFP16CPUKernel::Run() {
}
if (RepackWeight() != RET_OK) {
MS_LOG(ERROR) << "Repack weight failed.";
FreeTmpBuffer();
return RET_ERROR;
}
ret = ParallelLaunch(this->ms_context_, ConvolutionFp16Impl, this, thread_count_);


@ -242,6 +242,7 @@ int ConvolutionWinogradFP16CPUKernel::Run() {
}
if (RepackWeight() != RET_OK) {
MS_LOG(ERROR) << "Repack weight failed.";
FreeTmpBuffer();
return RET_ERROR;
}
ret = ParallelLaunch(this->ms_context_, ConvolutionWinogradFp16Impl, this, thread_count_);


@ -85,6 +85,8 @@ int DeconvolutionDepthwiseFp16CPUKernel::MallocWeightBiasData() {
bias_data_ = malloc(C8NUM * OC8 * sizeof(float16_t));
if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed.";
free(packed_weight_);
packed_weight_ = nullptr;
return RET_ERROR;
}
memset(bias_data_, 0, C8NUM * OC8 * sizeof(float16_t));
@ -95,7 +97,7 @@ int DeconvolutionDepthwiseFp16CPUKernel::MallocWeightBiasData() {
void DeconvolutionDepthwiseFp16CPUKernel::PackWeight() {
auto weight_tensor = in_tensors_.at(kWeightIndex);
void *origin_weight = (op_parameter_->is_train_session_) ? weight_tensor->data() : origin_weight_;
MS_ASSERT(origin_weight != nullptr);
NNACL_CHECK_NULL_RETURN_VOID(origin_weight);
PackNCHWFp16ToNC8HW8Fp16(reinterpret_cast<float16_t *>(origin_weight), reinterpret_cast<float16_t *>(packed_weight_),
1, weight_tensor->Height() * weight_tensor->Width(), weight_tensor->Batch());
}
@ -171,6 +173,13 @@ int DeconvolutionDepthwiseFp16CPUKernel::Run() {
return RET_ERROR;
}
auto input_tensor = in_tensors_.at(kInputIndex);
auto output_tensor = out_tensors_.at(kOutputIndex);
auto *input_ptr = reinterpret_cast<float16_t *>(input_tensor->data());
auto *output_ptr = reinterpret_cast<float16_t *>(output_tensor->data());
CHECK_NULL_RETURN(input_ptr);
CHECK_NULL_RETURN(output_ptr);
auto ret = InitPackedInputOutput();
if (ret != 0) {
MS_LOG(ERROR) << "Deconvolution depthwise fp16 InitPackedInputOutput failed.";
@ -179,16 +188,10 @@ int DeconvolutionDepthwiseFp16CPUKernel::Run() {
}
if (RepackWeight() != RET_OK) {
MS_LOG(ERROR) << "Repack weight failed.";
FreePackedInputOutput();
return RET_ERROR;
}
auto input_tensor = in_tensors_.at(kInputIndex);
auto output_tensor = out_tensors_.at(kOutputIndex);
auto *input_ptr = reinterpret_cast<float16_t *>(input_tensor->data());
auto *output_ptr = reinterpret_cast<float16_t *>(output_tensor->data());
CHECK_NULL_RETURN(input_ptr);
CHECK_NULL_RETURN(output_ptr);
if (need_align_) {
PackNHWCToNHWC8Fp16(input_ptr, packed_input_, conv_param_->input_batch_,
conv_param_->input_h_ * conv_param_->input_w_, conv_param_->input_channel_);


@ -60,7 +60,7 @@ void DeConvolutionFp16CPUKernel::PackWeight() {
auto kernel_h = weight_tensor->Height();
auto kernel_w = weight_tensor->Width();
void *origin_weight = (op_parameter_->is_train_session_) ? weight_tensor->data() : origin_weight_;
MS_ASSERT(origin_weight != nullptr);
CHECK_NULL_RETURN_VOID(origin_weight);
PackNHWCFp16ToC8HWN8Fp16(reinterpret_cast<float16_t *>(origin_weight), reinterpret_cast<float16_t *>(packed_weight_),
input_channel, kernel_w * kernel_h, output_channel);
}


@ -24,6 +24,8 @@ using mindspore::schema::PrimitiveType_ExpFusion;
namespace mindspore::kernel {
int ExpFp16CPUKernel::DoExcute(int task_id) {
CHECK_NULL_RETURN(input_addr_);
CHECK_NULL_RETURN(output_addr_);
ExpFusionFp16(reinterpret_cast<float16_t *>(input_addr_), reinterpret_cast<float16_t *>(output_addr_), param_,
task_id);
return RET_OK;


@ -29,34 +29,15 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Gather;
namespace mindspore::kernel {
namespace {
constexpr int kSecondInput = 2;
}
GatherFp16CPUKernel::~GatherFp16CPUKernel() {
if (input_data_) {
ms_context_->allocator->Free(input_data_);
input_data_ = nullptr;
}
}
int GatherFp16CPUKernel::Prepare() {
CHECK_LESS_RETURN(in_tensors_.size(), 3);
CHECK_LESS_RETURN(out_tensors_.size(), 1);
auto input_tensor = in_tensors_.at(0);
CHECK_NULL_RETURN(input_tensor);
if (input_tensor->data_type() == kNumberTypeFloat32 && input_tensor->data() != nullptr) {
const_input_ = true;
input_data_ =
reinterpret_cast<float16_t *>(ms_context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t)));
if (input_data_ == nullptr) {
MS_LOG(ERROR) << "Malloc failed";
return RET_ERROR;
}
Float32ToFloat16(reinterpret_cast<float *>(input_tensor->data()), input_data_, input_tensor->ElementsNum());
}
CHECK_NULL_RETURN(in_tensors_.at(kSecondInput)->data());
(reinterpret_cast<GatherParameter *>(op_parameter_))->axis_ =
*(reinterpret_cast<int *>(in_tensors_.at(kSecondInput)->data()));
CHECK_NULL_RETURN(in_tensors_[FIRST_INPUT]);
CHECK_NULL_RETURN(in_tensors_[SECOND_INPUT]);
CHECK_NULL_RETURN(in_tensors_[THIRD_INPUT]);
CHECK_NULL_RETURN(out_tensors_[kOutputIndex]);
CHECK_NULL_RETURN(in_tensors_[THIRD_INPUT]->data());
(reinterpret_cast<GatherParameter *>(op_parameter_))->axis_ = *(static_cast<int *>(in_tensors_[THIRD_INPUT]->data()));
if (!InferShapeDone()) {
return RET_OK;
}
@ -89,9 +70,7 @@ int GatherFp16CPUKernel::DoGather(int task_id) {
}
auto thread_stride = stride * task_id;
int8_t *int8_in = nullptr;
if (input_tensor->data_type() == kNumberTypeFloat32) {
int8_in = reinterpret_cast<int8_t *>(input_data_);
} else if (input_tensor->data_type() == kNumberTypeFloat16) {
if (input_tensor->data_type() == kNumberTypeFloat16) {
int8_in = reinterpret_cast<int8_t *>(input_tensor->data());
} else {
MS_LOG(ERROR) << "input data type error";
@ -121,10 +100,6 @@ void GatherFp16CPUKernel::FreeIndicesData() {
ms_context_->allocator->Free(indices_data_);
indices_data_ = nullptr;
}
if (!const_input_ && input_data_) {
ms_context_->allocator->Free(input_data_);
input_data_ = nullptr;
}
}
int GatherFp16CPUKernel::Run() {
@ -136,20 +111,6 @@ int GatherFp16CPUKernel::Run() {
MS_LOG(ERROR) << "AssignIndicesData failed, error_code[" << ret << "]";
return ret;
}
if (!const_input_) {
auto input_tensor = in_tensors_.at(0);
CHECK_NULL_RETURN(input_tensor->data());
if (input_tensor->data_type() == kNumberTypeFloat32) {
input_data_ =
reinterpret_cast<float16_t *>(ms_context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t)));
if (input_data_ == nullptr) {
MS_LOG(ERROR) << "Malloc data failed";
FreeIndicesData();
return RET_ERROR;
}
Float32ToFloat16(reinterpret_cast<float *>(input_tensor->data()), input_data_, input_tensor->ElementsNum());
}
}
ret = ParallelLaunch(this->ms_context_, GatherRunFp16, this, op_parameter_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Gather function error error_code[" << ret << "]";


@ -30,7 +30,7 @@ class GatherFp16CPUKernel : public InnerKernel {
GatherFp16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
: InnerKernel(parameter, inputs, outputs, ctx) {}
~GatherFp16CPUKernel() override;
~GatherFp16CPUKernel() = default;
int Prepare() override;
int ReSize() override;
@ -41,8 +41,6 @@ class GatherFp16CPUKernel : public InnerKernel {
int *indices_data_ = nullptr;
int AssignIndicesData(bool isIndicesInt32, int indices_num, const lite::Tensor *indices_tensor);
void FreeIndicesData();
float16_t *input_data_ = nullptr;
bool const_input_ = false;
bool is_indices_int32_ = false;
};
} // namespace mindspore::kernel


@ -146,8 +146,13 @@ int GroupConvolutionFP16CPUKernel::Prepare() {
MS_LOG(ERROR) << "GetSingleConv for fp16 group conv failed.";
return lite::RET_ERROR;
}
group_convs_.emplace_back(new (std::nothrow) ConvolutionDelegateFP16CPUKernel(
reinterpret_cast<OpParameter *>(new_conv_param), new_inputs, new_outputs, ctx_));
auto kernel = new (std::nothrow)
ConvolutionDelegateFP16CPUKernel(reinterpret_cast<OpParameter *>(new_conv_param), new_inputs, new_outputs, ctx_);
if (kernel == nullptr) {
MS_LOG(ERROR) << "Create kernel failed.";
return lite::RET_ERROR;
}
group_convs_.push_back(kernel);
}
return GroupConvolutionBaseCPUKernel::Prepare();
}


@ -21,6 +21,7 @@
#include "nnacl/fp16/gru_fp16.h"
#include "nnacl/fp16/cast_fp16.h"
#include "nnacl/fp16/lstm_fp16.h"
#include "nnacl/errorcode.h"
using mindspore::kernel::KERNEL_ARCH;
using mindspore::lite::KernelRegistrar;
@ -68,9 +69,11 @@ int GruFp16CPUKernel::InitParam() {
auto weight_g = in_tensors_.at(1);
MS_ASSERT(weight_g != nullptr);
std::vector<int> w_shape = weight_g->shape();
NNACL_CHECK_ZERO_RETURN_ERR(gate_num);
gru_param_->hidden_size_ = w_shape.at(1) / gate_num;
weight_batch_ = gru_param_->bidirectional_ ? 2 * gate_num : gate_num;
gru_param_->output_step_ = gru_param_->bidirectional_ ? 2 * gru_param_->batch_ * gru_param_->hidden_size_
constexpr int twice = 2;
weight_batch_ = gru_param_->bidirectional_ ? twice * gate_num : gate_num;
gru_param_->output_step_ = gru_param_->bidirectional_ ? twice * gru_param_->batch_ * gru_param_->hidden_size_
: gru_param_->batch_ * gru_param_->hidden_size_;
gru_param_->input_row_align_ = UP_ROUND(gru_param_->seq_len_ * gru_param_->batch_, C16NUM);
@ -189,8 +192,8 @@ int GruFp16CPUKernel::InitStateWeightBias() {
}
int GruFp16CPUKernel::Prepare() {
CHECK_LESS_RETURN(in_tensors_.size(), 5);
CHECK_LESS_RETURN(out_tensors_.size(), 2);
CHECK_LESS_RETURN(in_tensors_.size(), C5NUM);
CHECK_LESS_RETURN(out_tensors_.size(), C2NUM);
if (!InferShapeDone()) {
return RET_OK;
}
@ -270,7 +273,7 @@ int GruFp16CPUKernel::Run() {
CHECK_NULL_RETURN(hidden_state->data());
memcpy(output_hidden_state->data(), hidden_state->data(), hidden_state->ElementsNum() * sizeof(float16_t));
int check_seq_len = gru_param_->seq_len_;
if (in_tensors_.size() == 6) {
if (in_tensors_.size() == C6NUM) {
MS_ASSERT(in_tensors_.at(5) != nullptr);
int *seq_len = reinterpret_cast<int *>(in_tensors_.at(5)->data());
MS_ASSERT(seq_len != nullptr);


@ -22,6 +22,7 @@
#include "include/errorcode.h"
#include "nnacl/fp16/lstm_fp16.h"
#include "nnacl/fp16/cast_fp16.h"
#include "nnacl/errorcode.h"
using mindspore::kernel::KERNEL_ARCH;
using mindspore::lite::KernelRegistrar;
@ -73,9 +74,11 @@ int LstmFp16CPUKernel::InitParam() {
auto weight_i = in_tensors_.at(1);
std::vector<int> w_shape = weight_i->shape();
NNACL_CHECK_ZERO_RETURN_ERR(gate_num);
lstm_param_->hidden_size_ = w_shape.at(1) / gate_num;
lstm_param_->output_step_ = lstm_param_->bidirectional_ ? 2 * lstm_param_->batch_ * lstm_param_->hidden_size_
constexpr int twice = 2;
lstm_param_->output_step_ = lstm_param_->bidirectional_ ? twice * lstm_param_->batch_ * lstm_param_->hidden_size_
: lstm_param_->batch_ * lstm_param_->hidden_size_;
weight_batch_ = lstm_param_->bidirectional_ ? 2 * gate_num : gate_num;
lstm_param_->input_row_align_ = UP_ROUND(lstm_param_->seq_len_ * lstm_param_->batch_, C16NUM);


@ -26,6 +26,8 @@ using mindspore::schema::PrimitiveType_HashtableLookup;
namespace mindspore::kernel {
int HashtableLookupCPUKernel::Prepare() {
CHECK_LESS_RETURN(in_tensors_.size(), C3NUM);
CHECK_LESS_RETURN(out_tensors_.size(), C2NUM);
if (!InferShapeDone()) {
return RET_OK;
}
@ -39,11 +41,16 @@ static int CmpKeyFunc(const void *lhs, const void *rhs) {
}
int HashtableLookupCPUKernel::Run() {
auto input_tensor = in_tensors_.at(0);
auto keys_tensor = in_tensors_.at(1);
auto values_tensor = in_tensors_.at(2);
auto output_tensor = out_tensors_.at(0);
auto hits_tensor = out_tensors_.at(1);
auto input_tensor = in_tensors_[FIRST_INPUT];
auto keys_tensor = in_tensors_[SECOND_INPUT];
auto values_tensor = in_tensors_[THIRD_INPUT];
auto output_tensor = out_tensors_[FIRST_INPUT];
auto hits_tensor = out_tensors_[SECOND_INPUT];
CHECK_NULL_RETURN(input_tensor);
CHECK_NULL_RETURN(keys_tensor);
CHECK_NULL_RETURN(values_tensor);
CHECK_NULL_RETURN(output_tensor);
CHECK_NULL_RETURN(hits_tensor);
int rows = GetStringCount(values_tensor);
if (rows < 0) {


@ -31,6 +31,8 @@ constexpr int LABEL_INDEX = 2;
constexpr int WEIGHT_INDEX = 3;
} // namespace
int PredictCPUKernel::Prepare() {
CHECK_LESS_RETURN(in_tensors_.size(), C4NUM);
CHECK_LESS_RETURN(out_tensors_.size(), C2NUM);
if (!InferShapeDone()) {
return RET_OK;
}
@ -41,16 +43,19 @@ int PredictCPUKernel::ReSize() { return RET_OK; }
std::vector<LabelInfo> PredictCPUKernel::GetLabelInfo() {
std::vector<LabelInfo> label_info_vec;
auto input_tensor = in_tensors_.at(INPUT_INDEX);
auto keys_tensor = in_tensors_.at(KEY_INDEX);
auto labels_tensor = in_tensors_.at(LABEL_INDEX);
auto weights_tensor = in_tensors_.at(WEIGHT_INDEX);
auto input_tensor = in_tensors_[INPUT_INDEX];
auto keys_tensor = in_tensors_[KEY_INDEX];
auto labels_tensor = in_tensors_[LABEL_INDEX];
auto weights_tensor = in_tensors_[WEIGHT_INDEX];
if (input_tensor == nullptr || keys_tensor == nullptr || labels_tensor == nullptr || weights_tensor == nullptr) {
return label_info_vec;
}
int32_t *input = reinterpret_cast<int32_t *>(input_tensor->MutableData());
int32_t *key_begin = reinterpret_cast<int32_t *>(keys_tensor->MutableData());
int32_t *input = reinterpret_cast<int32_t *>(input_tensor->data());
int32_t *key_begin = reinterpret_cast<int32_t *>(keys_tensor->data());
int32_t *key_end = key_begin + keys_tensor->ElementsNum();
int32_t *labels = reinterpret_cast<int32_t *>(labels_tensor->MutableData());
float *weights = reinterpret_cast<float *>(weights_tensor->MutableData());
int32_t *labels = reinterpret_cast<int32_t *>(labels_tensor->data());
float *weights = reinterpret_cast<float *>(weights_tensor->data());
int32_t input_elements_num = input_tensor->ElementsNum();
int32_t items = labels_tensor->shape().at(1);
@ -82,10 +87,12 @@ int PredictCPUKernel::Run() {
std::vector<LabelInfo> label_info_vec = GetLabelInfo();
std::sort(label_info_vec.begin(), label_info_vec.end(), LabelInfoCmp);
auto output_label_tensor = out_tensors_.at(0);
auto output_weight_tensor = out_tensors_.at(1);
auto output_label = reinterpret_cast<int32_t *>(output_label_tensor->MutableData());
auto output_weight = reinterpret_cast<float *>(output_weight_tensor->MutableData());
auto output_label_tensor = out_tensors_[FIRST_INPUT];
auto output_weight_tensor = out_tensors_[SECOND_INPUT];
CHECK_NULL_RETURN(output_label_tensor);
CHECK_NULL_RETURN(output_weight_tensor);
auto output_label = reinterpret_cast<int32_t *>(output_label_tensor->data());
auto output_weight = reinterpret_cast<float *>(output_weight_tensor->data());
auto param = reinterpret_cast<PredictParameter *>(op_parameter_);
for (int i = 0; i < output_label_tensor->ElementsNum(); i++) {
if (static_cast<size_t>(i) >= label_info_vec.size() || label_info_vec[i].weight < param->weight_threshold) {