From: @zhaozhenlong
Reviewed-by: @jpc_chenjianping,@zhanghaibo5
Signed-off-by: @zhanghaibo5
This commit is contained in:
mindspore-ci-bot 2021-06-04 17:16:34 +08:00 committed by Gitee
commit 3ec4530ede
79 changed files with 362 additions and 272 deletions

View File

@ -15,7 +15,7 @@
*/
#include "src/runtime/kernel/arm/fp32/gatherNd_fp32.h"
#include <string.h>
#include <cstring>
#include <limits>
#include <vector>
#include "schema/model_generated.h"

View File

@ -144,9 +144,10 @@ int L2NormCPUKernel::Run() {
auto input_shape = in_tensors().at(kInputIndex)->shape();
input_ptr_ = reinterpret_cast<float *>(in_tensors_.at(kInputIndex)->MutableData());
output_ptr_ = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->MutableData());
int ret;
if (l2_norm_param_->axis_num_ == 0 || l2_norm_param_->axis_num_ == input_shape.size()) {
// all axis
auto ret = static_cast<const lite::InnerContext *>(this->context_)
ret = static_cast<const lite::InnerContext *>(this->context_)
->thread_pool_->ParallelLaunch(SquareSumRun, this, context_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "L2Norm error: error_code[" << ret << "]";
@ -164,7 +165,7 @@ int L2NormCPUKernel::Run() {
return RET_ERROR;
}
} else if (l2_norm_param_->axis_num_ == 1 && l2_norm_param_->axis_[0] == static_cast<int>(input_shape.size()) - 1) {
auto ret = static_cast<const lite::InnerContext *>(this->context_)
ret = static_cast<const lite::InnerContext *>(this->context_)
->thread_pool_->ParallelLaunch(L2NormTrailingAxisRun, this, context_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "L2Norm error: error_code[" << ret << "]";

View File

@ -28,7 +28,6 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_LRN;
namespace mindspore::kernel {
// Init does no work for this kernel; it unconditionally returns RET_OK.
int LocalResponseNormCPUKernel::Init() {
  return RET_OK;
}
// ReSize does no work for this kernel; it unconditionally returns RET_OK.
int LocalResponseNormCPUKernel::ReSize() {
  return RET_OK;
}

View File

@ -15,7 +15,7 @@
*/
#include "src/runtime/kernel/arm/fp32/lstm_fp32.h"
#include <float.h>
#include <cfloat>
#include <vector>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"

View File

@ -196,9 +196,9 @@ int MatmulFp32BaseCPUKernel::InitMatrixB(const float *src_ptr) {
memcpy(b_pack_ptr_, src_ptr, params_->batch * params_->col_ * params_->deep_ * sizeof(float));
} else {
for (int i = 0; i < params_->batch; i++) {
const float *src = src_ptr + i * params_->deep_ * params_->col_;
const float *src_data = src_ptr + i * params_->deep_ * params_->col_;
float *dst = b_pack_ptr_ + i * params_->deep_ * params_->col_;
RowMajor2ColMajor(src, dst, params_->deep_, params_->col_);
RowMajor2ColMajor(src_data, dst, params_->deep_, params_->col_);
}
}
return RET_OK;

View File

@ -85,7 +85,7 @@ void PadCPUKernel::InitMirrorPadBlock() {
int cur_input = 1;
int cur_output = 1;
for (size_t i = 0; i < COMM_SHAPE_SIZE; ++i) {
if (1 < cur_input) {
if (cur_input > 1) {
input_separate_dims.emplace_back(cur_input);
output_separate_dims.emplace_back(cur_output);
separate_offset.emplace_back(0);
@ -355,20 +355,21 @@ void PadCPUKernel::CalculateStrides() {
}
int PadCPUKernel::HandleMirrorPad() {
int ret;
if (in_tensors_.size() == 1) {
auto input_shape = in_tensors_.at(0)->shape();
int rank = static_cast<int>(input_shape.size());
auto ret = ExtendShape(in_, COMM_SHAPE_SIZE, input_shape.data(), rank);
ret = ExtendShape(in_, COMM_SHAPE_SIZE, input_shape.data(), rank);
if (ret != RET_OK) {
return ret;
}
} else {
auto ret = CopyPaddingFromInput();
ret = CopyPaddingFromInput();
if (ret != RET_OK) {
return ret;
}
}
auto ret = CheckPaddings(pad_param_->paddings_, COMM_SHAPE_SIZE, in_, pad_param_->pad_mode_);
ret = CheckPaddings(pad_param_->paddings_, COMM_SHAPE_SIZE, in_, pad_param_->pad_mode_);
if (ret != RET_OK) {
return ret;
}

View File

@ -15,7 +15,7 @@
*/
#include "src/runtime/kernel/arm/fp32/pooling_fp32.h"
#include <float.h>
#include <cfloat>
#include "nnacl/fp32/pooling_fp32.h"
#include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h"

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp32/reverse_fp32.h"
#include <string.h>
#include <cstring>
#include <vector>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"

View File

@ -29,7 +29,6 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_ROIPooling;
namespace mindspore::kernel {
int ROIPoolingCPUKernel::Init() {
if (!InferShapeDone()) {
return RET_OK;

View File

@ -15,7 +15,7 @@
*/
#include "src/runtime/kernel/arm/fp32/scale_fp32.h"
#include <string.h>
#include <cstring>
#include <vector>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"

View File

@ -15,7 +15,7 @@
*/
#include "src/runtime/kernel/arm/fp32/scatter_nd_fp32.h"
#include <string.h>
#include <cstring>
#include <vector>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
@ -85,7 +85,6 @@ int ScatterNDCPUKernel::ReSize() {
return RET_ERROR;
}
}
// for (size_t i = 0; i < static_cast<size_t>(indice_unit_rank); i++) {}
// calculate unit_size_
unit_size_ = 1;

View File

@ -27,7 +27,6 @@ using mindspore::lite::StringPack;
using mindspore::schema::PrimitiveType_SkipGram;
namespace mindspore::kernel {
int SkipGramCPUKernel::Init() {
if (!InferShapeDone()) {
return RET_OK;

View File

@ -15,7 +15,7 @@
*/
#include "src/runtime/kernel/arm/fp32/softmax_fp32.h"
#include <string.h>
#include <cstring>
#include <vector>
#include "nnacl/fp32/softmax_fp32.h"
#include "schema/model_generated.h"

View File

@ -32,7 +32,6 @@ using mindspore::lite::RET_PARAM_INVALID;
using mindspore::schema::PrimitiveType_SpaceToDepth;
namespace mindspore::kernel {
int SpaceToDepthCPUKernel::Init() {
SpaceToDepthParameter *param = reinterpret_cast<SpaceToDepthParameter *>(op_parameter_);
if (param->block_size_ <= 0) {

View File

@ -29,10 +29,9 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Adam;
namespace mindspore::kernel {
// ReSize does no work for the Adam kernel; it unconditionally returns RET_OK.
int AdamCPUKernel::ReSize() {
  return RET_OK;
}
static int DoAdam(float *m, float *v, float *gradient, float *weight, float beta1, float beta2, float beta1_power,
static int DoAdam(float *m, float *v, const float *gradient, float *weight, float beta1, float beta2, float beta1_power,
float beta2_power, float eps, float learning_rate, bool nesterov, int start, int end) {
if ((1.f - beta1_power) <= 0.0f) {
MS_LOG(ERROR) << "divisor cannot be 0 or below";

View File

@ -30,7 +30,7 @@ using mindspore::schema::PrimitiveType_ApplyMomentum;
namespace mindspore::kernel {
// ReSize does no work for the ApplyMomentum kernel; it unconditionally returns RET_OK.
int ApplyMomentumCPUKernel::ReSize() {
  return RET_OK;
}
static int DoApplyMomentum(float *weight, float *accumulate, float learning_rate, float *gradient, float moment,
static int DoApplyMomentum(float *weight, float *accumulate, float learning_rate, const float *gradient, float moment,
bool nesterov, int start, int end) {
if (nesterov) {
for (int i = start; i < end; i++) {

View File

@ -28,7 +28,6 @@ using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
namespace mindspore::kernel {
int ArithmeticGradCPUKernel::Init() {
auto dx1 = out_tensors_[0];
auto dx2 = out_tensors_[1];
@ -240,7 +239,7 @@ kernel::InnerKernel *CpuArithmeticGradFp32KernelCreator(const std::vector<lite::
const std::vector<lite::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc) {
MS_ASSERT(nullptr != opParameter);
MS_ASSERT(opParameter != nullptr);
if (opParameter == nullptr) {
return nullptr;
}

View File

@ -28,7 +28,6 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Assign;
namespace mindspore::kernel {
// ReSize does no work for the Assign kernel; it unconditionally returns RET_OK.
int AssignCPUKernel::ReSize() {
  return RET_OK;
}
int AssignCPUKernel::Execute(int task_id) {

View File

@ -28,7 +28,6 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_BiasAddGrad;
namespace mindspore::kernel {
int BiasGradCPUKernel::ReSize() {
auto dims = in_tensors_[0]->shape();
bias_param->ndim_ = dims.size();
@ -37,7 +36,7 @@ int BiasGradCPUKernel::ReSize() {
bias_param->out_shape_[i] = 1; // 1 dimension for N,H,W,
}
bias_param->out_shape_[bias_param->ndim_ - 1] = dims[bias_param->ndim_ - 1];
for (auto i = bias_param->ndim_; i < 4; i++) {
for (auto i = bias_param->ndim_; i < DIMENSION_4D; i++) {
bias_param->in_shape0_[i] = 0;
bias_param->out_shape_[i] = 0;
}

View File

@ -15,7 +15,7 @@
*/
#include "src/runtime/kernel/arm/fp32_grad/bn_grad.h"
#include <math.h>
#include <cmath>
#include <algorithm>
#include <vector>
#include <string>
@ -35,10 +35,15 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_BatchNormGrad;
namespace mindspore::kernel {
namespace {
// Factor applied to the channel count when BNGradCPUKernel::ReSize() computes ws_size_.
constexpr int kWsMultiplier{2};
// Exclusive upper bound of the stage-0 job loop in BNGradCPUKernel::Execute().
constexpr int kMaxTaskNum{4};
}  // namespace
// Recomputes ws_size_ from the kNHWC_C entry of input tensor 1's shape and
// registers the workspace size. Always returns RET_OK.
int BNGradCPUKernel::ReSize() {
auto *input_x = in_tensors_.at(1);
int channels = input_x->shape().at(kNHWC_C);
// NOTE(review): the next two assignments look like the before/after sides of a
// diff hunk shown without +/- markers. kWsMultiplier is 2, so both store the
// same value.
ws_size_ = 2 * channels;
ws_size_ = kWsMultiplier * channels;
// ws_size_ is scaled by sizeof(float) before being handed to set_workspace_size().
set_workspace_size(ws_size_ * sizeof(float));
return RET_OK;
}
@ -85,7 +90,7 @@ int BNGradCPUKernel::Execute(int task_id) {
count = (count < 0) ? 0 : count;
switch (stage) {
case 0: {
for (int job = task_id; job < 4; job += thread_num) {
for (int job = task_id; job < kMaxTaskNum; job += thread_num) {
switch (job) {
case 0:
var2Invar(save_var, input_var->ElementsNum(), bn_param->epsilon_);
@ -134,8 +139,9 @@ int BNGradRun(void *cdata, int task_id) {
int BNGradCPUKernel::Run() {
stage_ = 0;
thread_num_ = context_->thread_num_;
int error_code;
if (thread_num_ == 1) {
int error_code = static_cast<const lite::InnerContext *>(this->context_)
error_code = static_cast<const lite::InnerContext *>(this->context_)
->thread_pool_->ParallelLaunch(BNGradRun, this, thread_num_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "BN function error error_code[" << error_code << "]";
@ -145,7 +151,7 @@ int BNGradCPUKernel::Run() {
const std::vector<int> threads = {thread_num_, 1, thread_num_};
for (size_t stage = 0; stage < threads.size(); stage++) {
stage_ = static_cast<int>(stage);
int error_code = static_cast<const lite::InnerContext *>(this->context_)
error_code = static_cast<const lite::InnerContext *>(this->context_)
->thread_pool_->ParallelLaunch(BNGradRun, this, threads.at(stage));
if (error_code != RET_OK) {
MS_LOG(ERROR) << "BN function error error_code[" << error_code << "]";

View File

@ -94,7 +94,7 @@ int ConvolutionTrainCPUKernel::Execute(int task_id) {
const int k_h = conv_param_->kernel_h_;
const int k_w = conv_param_->kernel_w_;
const int batch = conv_param_->output_batch_;
const int out_ch = conv_param_->output_channel_; // out_y->shape()[3];
const int out_ch = conv_param_->output_channel_;
const int groups = conv_param_->group_;
const int out_h = conv_param_->output_h_;
const int out_w = conv_param_->output_w_;
@ -103,18 +103,22 @@ int ConvolutionTrainCPUKernel::Execute(int task_id) {
const int k = k_h * k_w * in_ch / groups;
float *workspace_temp = static_cast<float *>(workspace());
float *mat_workspace = workspace_temp + ws_size_;
int real_chunk;
float *mat_a = nullptr;
float *im = nullptr;
const float *mat_b = nullptr;
float *mat_c = nullptr;
if (do_dw_) {
const int kernel_spatial = k_h * k_w;
for (int i = 0; i < batch; ++i) {
for (int ci = 0; ci < m; ci += chunk_) {
int real_chunk = MSMIN(m - ci, chunk_);
float *mat_a = workspace_temp;
float *im = x_addr + (i * in_ch * in_h * in_w);
real_chunk = MSMIN(m - ci, chunk_);
mat_a = workspace_temp;
im = x_addr + (i * in_ch * in_h * in_w);
RollingIm2ColPackDwUnitFp32(im, conv_param_, mat_a, real_chunk, ci);
for (int j = 0; j < groups; ++j) {
const float *mat_b = w_addr + j * nweights / groups;
float *mat_c = y_addr + (i * groups) * n * m + j * (out_ch / groups) + ci * out_ch;
mat_b = w_addr + j * nweights / groups;
mat_c = y_addr + (i * groups) * n * m + j * (out_ch / groups) + ci * out_ch;
GemmMatmul(0, 1, real_chunk, n, k, 1, mat_a + (j * kernel_spatial), k * groups, mat_b, k, 0, mat_c, out_ch,
mat_workspace);
}
@ -124,24 +128,24 @@ int ConvolutionTrainCPUKernel::Execute(int task_id) {
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < groups; ++j) {
for (int ci = 0; ci < m; ci += chunk_) {
int real_chunk = MSMIN(m - ci, chunk_);
float *mat_a = workspace_temp;
const float *mat_b = w_addr + j * nweights / groups;
float *mat_c = y_addr + (i * groups) * n * m + j * (out_ch / groups) + ci * out_ch;
float *im = x_addr + i * in_ch * in_h * in_w + j * (in_ch / groups);
real_chunk = MSMIN(m - ci, chunk_);
mat_a = workspace_temp;
mat_b = w_addr + j * nweights / groups;
mat_c = y_addr + (i * groups) * n * m + j * (out_ch / groups) + ci * out_ch;
im = x_addr + i * in_ch * in_h * in_w + j * (in_ch / groups);
RollingIm2ColPackUnitFp32(im, conv_param_, mat_a, real_chunk, ci);
GemmMatmul(0, 1, real_chunk, n, k, 1, mat_a, k, mat_b, k, 0, mat_c, out_ch, mat_workspace);
}
}
}
} else {
const float *mat_b = w_addr;
mat_b = w_addr;
const size_t in_plane_size = in_ch * in_h * in_w;
for (int i = 0; i < batch; ++i) {
float *im = x_addr + i * in_plane_size;
im = x_addr + i * in_plane_size;
for (int ci = 0; ci < m; ci += chunk_) {
int real_chunk = MSMIN(m - ci, chunk_);
float *mat_c = y_addr + i * n * m + ci * out_ch;
real_chunk = MSMIN(m - ci, chunk_);
mat_c = y_addr + i * n * m + ci * out_ch;
int input_height = ci / out_w * conv_param_->stride_h_;
int input_width = ci % out_w * conv_param_->stride_w_;
int offset = (input_height * in_w + input_width) * in_ch;

View File

@ -111,7 +111,11 @@ int ConvolutionGradFilterCPUKernel::Execute(int task_id) {
count = (count < 0) ? 0 : count;
int start = stride * task_id;
int end = start + count;
int real_chunk;
float *mat_a = nullptr;
float *mat_b = nullptr;
float *mat_c = nullptr;
float *im = nullptr;
if (do_dw_) {
#ifdef ENABLE_ARM
stride = UP_DIV(k_h * k_w, thread_num);
@ -128,13 +132,13 @@ int ConvolutionGradFilterCPUKernel::Execute(int task_id) {
const int kernel_spatial = k_h * k_w;
for (int i = 0; i < batch; ++i) {
for (int ci = 0; ci < m; ci += chunk_) {
int real_chunk = MSMIN(m - ci, chunk_);
float *mat_b = workspace_temp + task_id * ws_size_;
float *im = x_addr + (i * in_ch * in_h * in_w);
real_chunk = MSMIN(m - ci, chunk_);
mat_b = workspace_temp + task_id * ws_size_;
im = x_addr + (i * in_ch * in_h * in_w);
RollingIm2ColPackDwUnitFp32(im, conv_param, mat_b, real_chunk, ci);
for (int j = start; j < end; ++j) {
float *mat_a = dy_addr + (i * groups) * m * k + j * (out_ch / groups) + ci * out_ch;
float *mat_c = dw_addr + j * nweights / groups;
mat_a = dy_addr + (i * groups) * m * k + j * (out_ch / groups) + ci * out_ch;
mat_c = dw_addr + j * nweights / groups;
GemmMatmul(1, 0, k, n, real_chunk, 1, mat_a, out_ch, mat_b + (j * kernel_spatial), n * groups, 1, mat_c, n,
mat_workspace);
}
@ -145,11 +149,11 @@ int ConvolutionGradFilterCPUKernel::Execute(int task_id) {
for (int i = start; i < end; ++i) {
for (int ci = 0; ci < m; ci += chunk_) {
for (int j = 0; j < groups; ++j) {
int real_chunk = MSMIN(m - ci, chunk_);
float *mat_a = dy_addr + (i * groups) * m * k + j * (out_ch / groups) + ci * out_ch;
float *mat_b = workspace_temp + task_id * ws_size_;
float *mat_c = dw_addr + j * nweights / groups;
float *im = x_addr + (i * in_ch * in_h * in_w) + j * (in_ch / groups);
real_chunk = MSMIN(m - ci, chunk_);
mat_a = dy_addr + (i * groups) * m * k + j * (out_ch / groups) + ci * out_ch;
mat_b = workspace_temp + task_id * ws_size_;
mat_c = dw_addr + j * nweights / groups;
im = x_addr + (i * in_ch * in_h * in_w) + j * (in_ch / groups);
RollingIm2ColPackUnitFp32(im, conv_param, mat_b, real_chunk, ci);
GemmMatmul(1, 0, k, n, real_chunk, 1, mat_a, out_ch, mat_b, n, 0, mat_tmp, n, mat_workspace);
std::unique_lock<std::mutex> merge_lock(lock_);
@ -158,13 +162,13 @@ int ConvolutionGradFilterCPUKernel::Execute(int task_id) {
}
}
} else {
float *mat_c = dw_addr;
mat_c = dw_addr;
const size_t in_plane_size = in_ch * in_h * in_w;
for (int i = start; i < end; ++i) {
for (int ci = 0; ci < m; ci += chunk_) {
int real_chunk = MSMIN(m - ci, chunk_);
float *mat_a = dy_addr + i * m * k + ci * out_ch;
float *im = x_addr + i * in_plane_size;
real_chunk = MSMIN(m - ci, chunk_);
mat_a = dy_addr + i * m * k + ci * out_ch;
im = x_addr + i * in_plane_size;
int input_h = ci / out_w * conv_param->stride_h_;
int input_w = ci % out_w * conv_param->stride_w_;
int offset = (input_h * in_w + input_w) * in_ch;

View File

@ -77,8 +77,8 @@ int ConvolutionGradInputCPUKernel::Execute(int task_id) {
int in_ch = conv_param->input_channel_;
int in_h = conv_param->input_h_;
int in_w = conv_param->input_w_;
int k_h = conv_param->kernel_h_; // out_dw->shape()[1];
int k_w = conv_param->kernel_w_; // out_dw->shape()[2];
int k_h = conv_param->kernel_h_;
int k_w = conv_param->kernel_w_;
int batch = conv_param->output_batch_;
int out_ch = conv_param->output_channel_;
int groups = conv_param->group_;

View File

@ -31,7 +31,6 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Dropout;
namespace mindspore::kernel {
int DropoutCPUKernel::Init() {
auto param = reinterpret_cast<DropoutParameter *>(op_parameter_);
if (param == nullptr) {

View File

@ -31,7 +31,6 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_DropoutGrad;
namespace mindspore::kernel {
int DropoutGradCPUKernel::Init() {
auto param = reinterpret_cast<DropoutParameter *>(op_parameter_);
if (param == nullptr) {
@ -64,7 +63,6 @@ int DropoutGradCPUKernel::Execute(int task_id) {
auto length = in_tensors_.at(kInputIndex)->ElementsNum();
int stride = UP_DIV(length, thread_count_);
int count = MSMIN(stride, length - stride * task_id);
if (count > 0) {
int start = stride * task_id;
DropoutGrad(&(yt_ptr[start]), &(mask_ptr[start]), &(output_ptr[start]), count, scale_);

View File

@ -68,7 +68,6 @@ int PoolingGradCPUKernel::Execute(int task_id) {
auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
int stride = UP_DIV(pool_param->output_batch_, thread_num_);
int count = MSMIN(stride, pool_param->output_batch_ - stride * task_id);
if (count > 0) {
int in_batch_size = pool_param->input_h_ * pool_param->input_w_ * pool_param->input_channel_;
int out_batch_size = pool_param->output_h_ * pool_param->output_w_ * pool_param->input_channel_;

View File

@ -29,7 +29,6 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_ResizeGrad;
namespace mindspore::kernel {
float Scaling(size_t in_size, size_t out_size, bool align_corners) {
return (align_corners && out_size > 1) ? (in_size - 1) / static_cast<float>(out_size - 1)
: in_size / static_cast<float>(out_size);

View File

@ -29,10 +29,9 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_SGD;
namespace mindspore::kernel {
// ReSize does no work for the SGD kernel; it unconditionally returns RET_OK.
int SgdCPUKernel::ReSize() {
  return RET_OK;
}
int DoSgd(float *weight, float *accumulate, float *gradient, float learning_rate, float dampening, float moment,
int DoSgd(float *weight, float *accumulate, const float *gradient, float learning_rate, float dampening, float moment,
bool nesterov, int start, int end) {
if (moment > 0.f) {
if (nesterov) {
@ -54,8 +53,8 @@ int DoSgd(float *weight, float *accumulate, float *gradient, float learning_rate
return RET_OK;
}
int DoSgdInit(float *weight, float *accumulate, float *gradient, float *stat, float learning_rate, float dampening,
float moment, bool nesterov, int start, int end) {
int DoSgdInit(float *weight, float *accumulate, float *gradient, float *stat, float learning_rate, float moment,
bool nesterov, int start, int end) {
std::copy(&(gradient[start]), &(gradient[end]), &(accumulate[start]));
if (nesterov) {
for (int i = start; i < end; ++i) {
@ -106,8 +105,7 @@ int SgdCPUKernel::ExecuteInit(int task_id) {
int end = start + count;
if (count > 0) {
DoSgdInit(weight, accumulate, gradient, stat, learning_rate, sgd_param_->dampening_, moment,
sgd_param_->use_nesterov_, start, end);
DoSgdInit(weight, accumulate, gradient, stat, learning_rate, moment, sgd_param_->use_nesterov_, start, end);
}
return RET_OK;
}
@ -192,8 +190,7 @@ int SgdCPUKernel::OptimizerStep() {
DoSgd(weight, accumulate, grad_sum_, learning_rate, sgd_param_->dampening_, moment, sgd_param_->use_nesterov_,
start, end);
} else {
DoSgdInit(weight, accumulate, grad_sum_, stat, learning_rate, sgd_param_->dampening_, moment,
sgd_param_->use_nesterov_, start, end);
DoSgdInit(weight, accumulate, grad_sum_, stat, learning_rate, moment, sgd_param_->use_nesterov_, start, end);
}
std::fill(grad_sum_, grad_sum_ + length, 0);
OptimizerKernel::OptimizerStep();

View File

@ -15,7 +15,7 @@
*/
#include "src/runtime/kernel/arm/fp32_grad/sigmoid_cross_entropy_with_logits.h"
#include <math.h>
#include <cmath>
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
@ -26,7 +26,6 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_SigmoidCrossEntropyWithLogits;
namespace mindspore::kernel {
// ReSize does no work for this kernel; it unconditionally returns RET_OK.
int SigmoidCrossEntropyWithLogitsCPUKernel::ReSize() {
  return RET_OK;
}
int SigmoidCrossEntropyWithLogitsCPUKernel::Execute(int task_id) {

View File

@ -15,7 +15,7 @@
*/
#include "src/runtime/kernel/arm/fp32_grad/sigmoid_cross_entropy_with_logits_grad.h"
#include <math.h>
#include <cmath>
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
@ -26,7 +26,6 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_SigmoidCrossEntropyWithLogitsGrad;
namespace mindspore::kernel {
// ReSize does no work for this kernel; it unconditionally returns RET_OK.
int SigmoidCrossEntropyWithLogitsGradCPUKernel::ReSize() {
  return RET_OK;
}
int SigmoidCrossEntropyWithLogitsGradCPUKernel::Execute(int task_id) {

View File

@ -25,7 +25,6 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_SmoothL1Loss;
namespace mindspore::kernel {
// ReSize does no work for this kernel; it unconditionally returns RET_OK.
int SmoothL1LossCPUKernel::ReSize() {
  return RET_OK;
}
int SmoothL1LossCPUKernel::Execute(int task_id) {

View File

@ -25,7 +25,6 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_SmoothL1LossGrad;
namespace mindspore::kernel {
// ReSize does no work for this kernel; it unconditionally returns RET_OK.
int SmoothL1LossGradCPUKernel::ReSize() {
  return RET_OK;
}
int SmoothL1LossGradCPUKernel::Execute(int task_id) {

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
#include <math.h>
#include <cmath>
#include "src/kernel_registry.h"
#include "nnacl/softmax_parameter.h"
#include "nnacl/fp32/softmax_fp32.h"
@ -28,7 +28,6 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_SoftmaxCrossEntropyWithLogits;
namespace mindspore::kernel {
// Init delegates entirely to ReSize() and forwards its status code unchanged.
int SoftmaxCrossEntropyWithLogitsCPUKernel::Init() {
  const int resize_status = ReSize();
  return resize_status;
}
void SoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const float *labels, const float *logits, float *grads,

View File

@ -15,7 +15,7 @@
*/
#include "src/runtime/kernel/arm/fp32_grad/softmax_grad.h"
#include <string.h>
#include <cstring>
#include <vector>
#include "nnacl/fp32_grad/softmax_grad.h"
#include "schema/model_generated.h"

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -14,7 +14,7 @@
* limitations under the License.
*/
#include <math.h>
#include <cmath>
#include "src/kernel_registry.h"
#include "nnacl/softmax_parameter.h"
#include "nnacl/fp32/softmax_fp32.h"
@ -29,7 +29,6 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits;
namespace mindspore::kernel {
// ReSize does no work for this kernel; it unconditionally returns RET_OK.
int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::ReSize() {
  return RET_OK;
}
int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const int *labels, const float *losses,
@ -162,7 +161,7 @@ int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Init() {
param->number_of_classes_ = dims.at(1);
param->batch_size_ = dims.at(0);
for (unsigned int i = 0; i < dims.size(); i++) param->input_shape_[i] = dims.at(i);
if (2 != this->in_tensors_.size()) {
if (this->in_tensors_.size() != 2) {
MS_LOG(ERROR) << "sparse softmax entropy loss should have two inputs";
return RET_ERROR;
}

View File

@ -31,7 +31,6 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_StridedSliceGrad;
namespace mindspore::kernel {
int StridedSliceGradCPUKernel::Init() {
if (!InferShapeDone()) {
return RET_OK;

View File

@ -30,7 +30,6 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_UnsortedSegmentSum;
namespace mindspore::kernel {
int UnsortedSegmentSumCPUKernel::Init() {
if (!InferShapeDone()) {
return RET_OK;

View File

@ -27,6 +27,10 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_AddFusion;
namespace mindspore::kernel {
namespace {
// Value that QuantizedAddCPUKernel::Init() stores into para_->left_shift_.
constexpr int kBaseShift{20};
}  // namespace
QuantizedAddCPUKernel::~QuantizedAddCPUKernel() {
if (para_ != nullptr) {
free(para_);
@ -52,7 +56,7 @@ int QuantizedAddCPUKernel::Init() {
const double in1_scale = input1->quant_params().front().scale;
const double out_scale = output->quant_params().front().scale;
para_->left_shift_ = 20;
para_->left_shift_ = kBaseShift;
const double twice_max_input_scale = 2 * std::max(in0_scale, in1_scale);
const double in0_multiplier = in0_scale / twice_max_input_scale;
const double in1_multiplier = in1_scale / twice_max_input_scale;

View File

@ -107,6 +107,7 @@ int ArithmeticInt8CPUKernel::DoArithmetic(int thread_id) {
auto output_data = reinterpret_cast<uint8_t *>(out_tensors_[0]->MutableData());
auto element_num = out_tensors_[0]->ElementsNum();
auto param = reinterpret_cast<ArithmeticParameter *>(op_parameter_);
int error_code;
if (param->broadcasting_ && arithmetic_run_ != nullptr) {
MS_ASSERT(op_parameter_->thread_num_ != 0);
int stride = UP_DIV(element_num, op_parameter_->thread_num_);
@ -115,14 +116,14 @@ int ArithmeticInt8CPUKernel::DoArithmetic(int thread_id) {
return RET_OK;
}
int error_code = arithmetic_run_(tile_data0_ + stride * thread_id, tile_data1_ + stride * thread_id,
error_code = arithmetic_run_(tile_data0_ + stride * thread_id, tile_data1_ + stride * thread_id,
output_data + stride * thread_id, count, &quant_args_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "Arithmetic run fail! ret: " << error_code;
return error_code;
}
} else if (arithmetic_run_ != nullptr) {
int error_code = arithmetic_run_(input0_data, input1_data1, output_data, element_num, &quant_args_);
error_code = arithmetic_run_(input0_data, input1_data1, output_data, element_num, &quant_args_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "Arithmetic run fail!ret: " << error_code;
return error_code;

View File

@ -63,7 +63,7 @@ int BatchToSpaceInt8CPUKernel::Init() {
}
// Asserts on the rank of input tensor 0's shape and otherwise does nothing.
// Always returns RET_OK.
int BatchToSpaceInt8CPUKernel::ReSize() {
// NOTE(review): the two asserts below look like the before/after sides of a
// diff hunk shown without +/- markers. DIMENSION_4D is defined outside this
// view and is presumably 4 — confirm.
MS_ASSERT(in_tensors_.at(0)->shape().size() == 4);
MS_ASSERT(in_tensors_.at(0)->shape().size() == DIMENSION_4D);
return RET_OK;
}

View File

@ -15,7 +15,7 @@
*/
#include "src/runtime/kernel/arm/int8/batchnorm_int8.h"
#include <math.h>
#include <cmath>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"

View File

@ -40,9 +40,9 @@ int ConcatInt8CPUKernel::Init() {
}
for (size_t i = 0; i < input_num; i++) {
auto *input_tensor = in_tensors_.at(i);
auto quant_args = input_tensor->quant_params();
concat_param_->quant_arg_.in_args_[i].scale_ = quant_args.front().scale;
concat_param_->quant_arg_.in_args_[i].zp_ = quant_args.front().zeroPoint;
auto in_quant_args = input_tensor->quant_params();
concat_param_->quant_arg_.in_args_[i].scale_ = in_quant_args.front().scale;
concat_param_->quant_arg_.in_args_[i].zp_ = in_quant_args.front().zeroPoint;
}
auto output_tensor = out_tensors_.at(kOutputIndex);

View File

@ -238,7 +238,7 @@ int Convolution1x1Int8CPUKernel::InitWeightBias() {
return RET_ERROR;
}
memset(bias_data_, 0, size * sizeof(int32_t));
if (in_tensors_.size() == 3) {
if (in_tensors_.size() == kInputSize2) {
memcpy(bias_data_, in_tensors_.at(kBiasIndex)->data_c(), output_channel * sizeof(int32_t));
}
@ -270,7 +270,7 @@ int Convolution1x1Int8CPUKernel::InitWeightBiasArm32() {
return RET_ERROR;
}
memset(bias_data_, 0, col2 * sizeof(int32_t));
if (in_tensors_.size() == 3) {
if (in_tensors_.size() == kInputSize2) {
memcpy(bias_data_, in_tensors_.at(kBiasIndex)->data_c(), output_channel * sizeof(int32_t));
}

View File

@ -23,6 +23,9 @@ using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
namespace mindspore::kernel {
namespace {
// 4 * 4 factor shared by the transformed-filter, block-unit and tmp-dst buffer
// size computations in Convolution3x3Int8CPUKernel.
// NOTE(review): the identifier is spelled "Multipler"; renaming it would need
// every user in the file updated at the same time.
constexpr size_t kUnitBufferMultipler{4 * 4};
}  // namespace
int ProcessFilterUint8(int8_t *origin_weight, int16_t *dst_weight, ConvParameter *conv_param) {
auto input_channel = conv_param->input_channel_;
auto output_channel = conv_param->output_channel_;
@ -82,7 +85,7 @@ int Convolution3x3Int8CPUKernel::InitWeightBias() {
int iC8 = UP_DIV(input_channel, C8NUM);
int oC4 = UP_DIV(output_channel, C4NUM);
// init weight
size_t transformed_size = iC8 * C8NUM * oC4 * C4NUM * 16 * sizeof(int16_t);
size_t transformed_size = iC8 * C8NUM * oC4 * C4NUM * kUnitBufferMultipler * sizeof(int16_t);
transformed_filter_addr_ = reinterpret_cast<int16_t *>(malloc(transformed_size));
if (transformed_filter_addr_ == nullptr) {
MS_LOG(ERROR) << "malloc transformed_filter_addr_ failed.";
@ -136,14 +139,14 @@ int Convolution3x3Int8CPUKernel::InitTmpBuffer() {
return RET_ERROR;
}
size_t block_unit_buffer_size = thread_count_ * 4 * 4 * C8NUM * sizeof(int16_t);
size_t block_unit_buffer_size = thread_count_ * kUnitBufferMultipler * C8NUM * sizeof(int16_t);
block_unit_buffer_ = reinterpret_cast<int16_t *>(ctx_->allocator->Malloc(block_unit_buffer_size));
if (block_unit_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc block_unit_buffer_ failed.";
return RET_ERROR;
}
size_t tmp_dst_buffer_size = thread_count_ * TILE_NUM * 16 * oc4 * C4NUM * sizeof(int32_t);
size_t tmp_dst_buffer_size = thread_count_ * TILE_NUM * kUnitBufferMultipler * oc4 * C4NUM * sizeof(int32_t);
tmp_dst_buffer_ = reinterpret_cast<int32_t *>(ctx_->allocator->Malloc(tmp_dst_buffer_size));
if (tmp_dst_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc tmp_dst_buffer_ failed.";

View File

@ -23,6 +23,10 @@ using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
namespace mindspore::kernel {
namespace {
// Per-task element count of buffer_ (64 * 10 * 10): Execute() offsets buffer_
// by this amount per task_id, and InitBuffer() multiplies it by thread_num_.
constexpr int kConvDepthwise3x3BufferSize{64 * 10 * 10};
// InitWeightBias() rejects weight tensors whose Batch() is not a multiple of this.
constexpr int kChannelUnit{8};
}  // namespace
ConvolutionDepthwise3x3Int8CPUKernel::~ConvolutionDepthwise3x3Int8CPUKernel() {
if (sliding_ != nullptr) {
delete sliding_;
@ -40,7 +44,7 @@ int ConvolutionDepthwise3x3Int8CPUKernel::InitWeightBias() {
auto weight_tensor = in_tensors_.at(kWeightIndex);
auto origin_weight = reinterpret_cast<int8_t *>(weight_tensor->MutableData());
int channel = weight_tensor->Batch();
if (channel % 8 != 0) {
if (channel % kChannelUnit != 0) {
MS_LOG(ERROR) << "ConvolutionDepthwise3x3Int8CPUKernel doesn't support channel " << channel;
return RET_ERROR;
}
@ -63,8 +67,8 @@ int ConvolutionDepthwise3x3Int8CPUKernel::InitWeightBias() {
if (filter_per_channel) {
for (int i = 0; i < weight_tensor->Height() * weight_tensor->Width(); i++) {
for (int c = 0; c < channel; c++) {
int weight_zp = conv_param_->conv_quant_arg_.filter_quant_args_[c].zp_;
packed_weight_[i * channel + c] = (int16_t)(tmp_weight[i * channel + c] - weight_zp);
int per_channel_weight_zp = conv_param_->conv_quant_arg_.filter_quant_args_[c].zp_;
packed_weight_[i * channel + c] = (int16_t)(tmp_weight[i * channel + c] - per_channel_weight_zp);
}
}
} else {
@ -119,7 +123,7 @@ int ConvolutionDepthwise3x3Int8CPUKernel::ReSize() {
}
// Calls ConvDw3x3Int8 for the given task_id, passing a per-task slice of
// buffer_ as scratch space. Always returns RET_OK.
int ConvolutionDepthwise3x3Int8CPUKernel::Execute(int task_id) {
// Each task starts 64 * 10 * 10 elements further into buffer_ than the previous one.
// NOTE(review): the two declarations of `buffer` below look like the
// before/after sides of a diff hunk shown without +/- markers.
// kConvDepthwise3x3BufferSize is 64 * 10 * 10, so both compute the same pointer.
auto buffer = buffer_ + 64 * 10 * 10 * task_id;
auto buffer = buffer_ + kConvDepthwise3x3BufferSize * task_id;
ConvDw3x3Int8(output_ptr_, buffer, input_ptr_, packed_weight_, reinterpret_cast<int32_t *>(bias_data_), conv_param_,
sliding_, task_id);
return RET_OK;
}
}
int ConvolutionDepthwise3x3Int8CPUKernel::InitBuffer() {
int buffer_size = 64 * 10 * 10 * conv_param_->thread_num_;
int buffer_size = kConvDepthwise3x3BufferSize * conv_param_->thread_num_;
buffer_ = reinterpret_cast<int8_t *>(context_->allocator->Malloc(buffer_size * sizeof(int8_t)));
if (buffer_ == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed.";

View File

@ -56,8 +56,8 @@ int ConvolutionDepthwiseInt8CPUKernel::InitWeightBias() {
if (filter_per_channel) {
for (int i = 0; i < weight_tensor->Height() * weight_tensor->Width(); i++) {
for (int c = 0; c < channel; c++) {
int weight_zp = conv_param_->conv_quant_arg_.filter_quant_args_[c].zp_;
packed_weight_[i * channel + c] = (int16_t)(tmp_weight[i * channel + c] - weight_zp);
int per_channel_weight_zp = conv_param_->conv_quant_arg_.filter_quant_args_[c].zp_;
packed_weight_[i * channel + c] = (int16_t)(tmp_weight[i * channel + c] - per_channel_weight_zp);
}
}
} else {

View File

@ -38,7 +38,7 @@ ConvolutionDepthwiseSWInt8CPUKernel::~ConvolutionDepthwiseSWInt8CPUKernel() {
int ConvolutionDepthwiseSWInt8CPUKernel::InitWeightBias() {
// init weight, int8 -> int16
// o, h, w, i -> o/8, h, w, i, 8; o == group, i == 1
// o, h, w, i -> o/8, h, w, i, 8; o equals group, i equals 1
auto weight_tensor = in_tensors_.at(kWeightIndex);
auto origin_weight = reinterpret_cast<int8_t *>(weight_tensor->MutableData());
int OC8 = UP_DIV(weight_tensor->Batch(), C8NUM);

View File

@ -37,7 +37,7 @@ DeconvolutionDepthwiseInt8CPUKernel::~DeconvolutionDepthwiseInt8CPUKernel() {
int DeconvolutionDepthwiseInt8CPUKernel::InitWeightBias() {
// init weight: int8 -> int16
// o, h, w, i -> o/8, h, w, i, 8; o == group, i == 1
// o, h, w, i -> o/8, h, w, i, 8; o equals group, i equals 1
auto weight_tensor = in_tensors_.at(kWeightIndex);
auto origin_weight = reinterpret_cast<int8_t *>(weight_tensor->MutableData());
int OC4 = UP_DIV(weight_tensor->Batch(), C4NUM);
@ -87,7 +87,7 @@ int DeconvolutionDepthwiseInt8CPUKernel::InitSlideParam() {
int DeconvolutionDepthwiseInt8CPUKernel::InitBuffer() {
int pack_input_size = conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * C4NUM *
UP_DIV(conv_param_->input_channel_, 4);
UP_DIV(conv_param_->input_channel_, C4NUM);
packed_input_ = reinterpret_cast<int16_t *>(context_->allocator->Malloc(pack_input_size * sizeof(int16_t)));
if (packed_input_ == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed.";

View File

@ -132,7 +132,7 @@ int DeConvInt8CPUKernel::InitBiasWeight() {
return RET_ERROR;
}
memset(bias_data_, 0, size);
if (in_tensors_.size() == 3) {
if (in_tensors_.size() == kInputSize2) {
memcpy(bias_data_, in_tensors_.at(0)->MutableData(), conv_param_->output_channel_ * sizeof(int32_t));
}

View File

@ -81,5 +81,4 @@ int DepthToSpaceInt8CPUKernel::Run() {
}
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_DepthToSpace, LiteKernelCreator<DepthToSpaceInt8CPUKernel>)
} // namespace mindspore::kernel

View File

@ -15,7 +15,7 @@
*/
#include "src/runtime/kernel/arm/int8/gatherNd_int8.h"
#include <string.h>
#include <cstring>
#include <limits>
#include <vector>
#include "schema/model_generated.h"

View File

@ -30,7 +30,6 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Gather;
namespace mindspore::kernel {
int GatherInt8CPUKernel::Init() {
axis_ = (reinterpret_cast<GatherParameter *>(op_parameter_))->axis_;
auto in_quant_args = in_tensors_.at(0)->quant_params();

View File

@ -262,8 +262,8 @@ int MatmulBaseInt8CPUKernel::InitTmpBuffer() {
}
int MatmulBaseInt8CPUKernel::InitBias() {
if (in_tensors_.size() == 3) {
auto bias_tensor = in_tensors_[2];
if (in_tensors_.size() == kInputSize2) {
auto bias_tensor = in_tensors_[kBiasIndex];
int max_bias_data = UP_ROUND(bias_tensor->ElementsNum(), C4NUM);
bias_ptr_ = reinterpret_cast<int *>(malloc(max_bias_data * sizeof(int)));
if (bias_ptr_ == nullptr) {

View File

@ -70,12 +70,12 @@ int MulInt8CPUKernel::Init() {
}
void MulInt8CPUKernel::CheckSameShapeSize(std::vector<int> in_tensor0_shape, std::vector<int> in_tensor1_shape) {
bool condition1 = in_tensor0_shape[0] == in_tensor1_shape[0];
bool condition2 = in_tensor0_shape[1] == 1;
bool condition3 = in_tensor0_shape[2] == 1;
bool condition4 = in_tensor0_shape[3] == in_tensor1_shape[3];
bool condition5 = in_tensor1_shape[1] == 1;
bool condition6 = in_tensor1_shape[2] == 1;
bool condition1 = in_tensor0_shape[kNHWC_N] == in_tensor1_shape[kNHWC_N];
bool condition2 = in_tensor0_shape[kNHWC_H] == 1;
bool condition3 = in_tensor0_shape[kNHWC_W] == 1;
bool condition4 = in_tensor0_shape[kNHWC_C] == in_tensor1_shape[kNHWC_C];
bool condition5 = in_tensor1_shape[kNHWC_H] == 1;
bool condition6 = in_tensor1_shape[kNHWC_W] == 1;
if (condition1 && condition2 && condition3 && condition4) {
fast_hw_broadcast_ = true;
} else if (condition1 && condition4 && condition5 && condition6) {
@ -91,11 +91,11 @@ void MulInt8CPUKernel::CheckIfFastImpl() {
if (in_tensor0->shape().size() == COMM_SHAPE_SIZE && in_tensor1->shape().size() == COMM_SHAPE_SIZE) {
CheckSameShapeSize(in_tensor0->shape(), in_tensor1->shape());
} else if (in_tensor0->shape().size() == 1 && in_tensor1->shape().size() == COMM_SHAPE_SIZE) {
if (in_tensor0->ElementsNum() == in_tensor1->shape()[3]) {
if (in_tensor0->ElementsNum() == in_tensor1->shape()[kNHWC_C]) {
fast_hw_broadcast_ = true;
}
} else if (in_tensor0->shape().size() == COMM_SHAPE_SIZE && in_tensor1->shape().size() == 1) {
if (in_tensor1->ElementsNum() == in_tensor0->shape()[3]) {
if (in_tensor1->ElementsNum() == in_tensor0->shape()[kNHWC_C]) {
fast_hw_broadcast_ = true;
input1_hw_broadcast_ = true;
}
@ -162,6 +162,7 @@ int MulInt8CPUKernel::Run() {
elements_num_ = out_tensors_.at(0)->ElementsNum();
count_unit_ = thread_count_ > 1 ? UP_DIV(elements_num_, thread_count_) : elements_num_;
int ret = RET_ERROR;
if (in_tensors_.at(0)->ElementsNum() != in_tensors_.at(1)->ElementsNum()) {
input0_data_ = static_cast<int8_t *>(ctx_->allocator->Malloc(out_tensors_.at(0)->Size()));
if (input0_data_ == nullptr) {
@ -176,14 +177,14 @@ int MulInt8CPUKernel::Run() {
}
TileDimensionsInt8(static_cast<int8_t *>(in_tensors_.at(0)->MutableData()),
static_cast<int8_t *>(in_tensors_.at(1)->MutableData()), input0_data_, input1_data_, tile_para);
auto ret = static_cast<const lite::InnerContext *>(this->context_)
ret = static_cast<const lite::InnerContext *>(this->context_)
->thread_pool_->ParallelLaunch(MulInt8Run, this, thread_count_);
ctx_->allocator->Free(input0_data_);
ctx_->allocator->Free(input1_data_);
return ret;
}
auto ret = static_cast<const lite::InnerContext *>(this->context_)
ret = static_cast<const lite::InnerContext *>(this->context_)
->thread_pool_->ParallelLaunch(MulInt8Run, this, thread_count_);
return ret;
}

View File

@ -15,23 +15,13 @@
*/
#include "src/runtime/kernel/arm/int8/opt_op_handler.h"
#include <stdlib.h>
#include <cstdlib>
#include "nnacl/op_base.h"
#ifdef __cplusplus
extern "C" {
#endif
extern void MatMulOptR4Int8Neon64(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16,
const int *input_sum, const int *bias);
extern void MatmulInt8DpNeon64(const int8_t *a, const int8_t *b, int8_t *dst, int row8, int col8, int deep4,
const int *a_sums, const int *bias, int act_min, int act_max, int out_zp,
int *multiplier, int *left_shift, int *right_shift, int row, int col, int stride,
size_t peroc);
extern void MatmulInt8DpOpt(const int8_t *a, const int8_t *b, int8_t *dst, size_t row8, size_t col8, size_t deep4,
const int *a_sums, const int *bias, int act_min, int act_max, int out_zp, int *multiplier,
int *left_shift, int *right_shift, size_t stride, size_t peroc, int *filter_zp);
#ifdef ENABLE_ARM64
void MatMulR4Int8_optimize_handler(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16,

View File

@ -22,6 +22,14 @@
#ifdef __cplusplus
extern "C" {
#endif
void MatMulOptR4Int8Neon64(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16,
const int *input_sum, const int *bias);
void MatmulInt8DpNeon64(const int8_t *a, const int8_t *b, int8_t *dst, int row8, int col8, int deep4, const int *a_sums,
const int *bias, int act_min, int act_max, int out_zp, int *multiplier, int *left_shift,
int *right_shift, int row, int col, int stride, size_t peroc);
void MatmulInt8DpOpt(const int8_t *a, const int8_t *b, int8_t *dst, size_t row8, size_t col8, size_t deep4,
const int *a_sums, const int *bias, int act_min, int act_max, int out_zp, int *multiplier,
int *left_shift, int *right_shift, size_t stride, size_t peroc, int *filter_zp);
#ifdef ENABLE_ARM64
void IndirectGemmInt8_optimize_handler(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias,
size_t ksize, size_t ic4, size_t output_channel, size_t offset,

View File

@ -26,7 +26,6 @@ using mindspore::lite::RET_OK;
using mindspore::lite::KernelRegistrar;
using mindspore::schema::PrimitiveType_PadFusion;
namespace mindspore::kernel {
namespace {
// File-local constant. NOTE(review): presumably the number of input tensors
// expected when the pad op runs in mirror mode -- its use is not visible in
// this hunk; confirm against the rest of the file.
constexpr size_t kMirrorPadInputSize = 2;
}

View File

@ -240,7 +240,7 @@ int ReduceInt8CPUKernel::CalculateQuantArgs() {
// (quant_out - zp_out)*scale_out = sum((quant_in -zp)*scale_in) * (1/num) for each axis in axes
// quant_out = sum(quant_in-zp) * (scale_in/scale_out) * (1/num)
if (mode_ == static_cast<int>(schema::ReduceMode_ReduceMean)) {
if (input->shape().size() == 4 && pattern_ == kernel::HW) {
if (input->shape().size() == DIMENSION_4D && pattern_ == kernel::HW) {
// special case, can use pattern
ReduceMean4DCalQuantParam();
pattern_impl_ = true;
@ -309,8 +309,8 @@ int ReduceInt8CPUKernel::CalculateQuantArgsReduceSumSquare() {
MS_LOG(ERROR) << "ReduceProd new QuantMultiplier failed.";
return RET_NULL_PTR;
}
double sumsquare_multiplier = quant_arg_.in_scale_ * quant_arg_.in_scale_ / quant_arg_.out_scale_;
QuantizeMultiplierSmallerThanOne(sumsquare_multiplier, &qm->multiplier_, &shift);
double last_sumsquare_multiplier = quant_arg_.in_scale_ * quant_arg_.in_scale_ / quant_arg_.out_scale_;
QuantizeMultiplierSmallerThanOne(last_sumsquare_multiplier, &qm->multiplier_, &shift);
qm->left_shift_ = shift < 0 ? -shift : 0;
qm->right_shift_ = shift > 0 ? shift : 0;
sum_square_multipliers_.push_back(qm);
@ -470,8 +470,9 @@ int ReduceInt8CPUKernel::Fast4DReduceMeanHWImpl() {
}
int ReduceInt8CPUKernel::Run() {
int ret;
if (!this->valid_shape_) {
auto ret = CalculateQuantArgs();
ret = CalculateQuantArgs();
if (ret != RET_OK) {
return ret;
}
@ -481,7 +482,7 @@ int ReduceInt8CPUKernel::Run() {
return Fast4DReduceMeanHWImpl();
}
auto ret = MallocTmpBuffer();
ret = MallocTmpBuffer();
if (ret != RET_OK) {
FreeTmpBuffer();
return ret;
@ -495,13 +496,14 @@ int ReduceInt8CPUKernel::Run() {
begin_src_data_[i] = static_cast<int32_t>(input_data[i]);
}
src_data_ = begin_src_data_;
int error_code = RET_ERROR;
for (size_t i = 0; i < data_buffers_.size(); ++i) {
GetQuantArgs(i);
dst_data_ = data_buffers_[i];
outer_size_ = outer_sizes_[i];
inner_size_ = inner_sizes_[i];
axis_size_ = axis_sizes_[i];
auto error_code = static_cast<const lite::InnerContext *>(this->context_)
error_code = static_cast<const lite::InnerContext *>(this->context_)
->thread_pool_->ParallelLaunch(ReduceInt8Impl, this, context_->thread_num_);
if (error_code != RET_OK) {
FreeTmpBuffer();
@ -517,7 +519,7 @@ int ReduceInt8CPUKernel::Run() {
axis_size_ = axis_sizes_.back();
last_dst_data_ = reinterpret_cast<int8_t *>(out_tensors_.at(0)->MutableData());
is_last_axis_ = true;
auto error_code = static_cast<const lite::InnerContext *>(this->context_)
error_code = static_cast<const lite::InnerContext *>(this->context_)
->thread_pool_->ParallelLaunch(ReduceInt8Impl, this, context_->thread_num_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "Reduce run error, error_code[" << error_code << "]";

View File

@ -29,7 +29,6 @@ using mindspore::lite::KernelRegistrar;
using mindspore::schema::PrimitiveType_Reshape;
namespace mindspore::kernel {
int ReshapeInt8CPUKernel::Init() {
auto *input_tensor = in_tensors_.at(kInputIndex);
auto in_quant_args = input_tensor->quant_params();

View File

@ -33,6 +33,9 @@ using mindspore::lite::RET_OK;
using mindspore::lite::KernelRegistrar;
namespace mindspore::kernel {
namespace {
// Number of fractional bits of the fixed-point values used by the int8 resize
// kernel: CalRatio() scales the in/out size ratios by (1 << OFFSET_BASE), and
// CalInterpolationRange() divides the scaled indices by the same factor to get
// the integer lower/upper source coordinates.
constexpr unsigned int OFFSET_BASE = 10;
} // namespace
void ResizeInt8CPUKernel::FreeResizeBiLinear() {
free(resize_quant_arg_.x_axis_index_);
free(resize_quant_arg_.x_axis_lower_);
@ -96,32 +99,32 @@ int ResizeInt8CPUKernel::Init() {
int ResizeInt8CPUKernel::InitResizeQuantArg() {
auto out_shape = out_tensors_.front()->shape();
resize_quant_arg_.x_axis_index_ = reinterpret_cast<int32_t *>(malloc(out_shape.at(2) * sizeof(int32_t)));
resize_quant_arg_.x_axis_index_ = reinterpret_cast<int32_t *>(malloc(out_shape.at(kNHWC_W) * sizeof(int32_t)));
if (resize_quant_arg_.x_axis_index_ == nullptr) {
MS_LOG(ERROR) << "malloc x axis index array failed.";
return RET_ERROR;
}
resize_quant_arg_.x_axis_lower_ = reinterpret_cast<int32_t *>(malloc(out_shape.at(2) * sizeof(int32_t)));
resize_quant_arg_.x_axis_lower_ = reinterpret_cast<int32_t *>(malloc(out_shape.at(kNHWC_W) * sizeof(int32_t)));
if (resize_quant_arg_.x_axis_lower_ == nullptr) {
MS_LOG(ERROR) << "malloc x_axis_lower_ array failed.";
return RET_ERROR;
}
resize_quant_arg_.x_axis_upper_ = reinterpret_cast<int32_t *>(malloc(out_shape.at(2) * sizeof(int32_t)));
resize_quant_arg_.x_axis_upper_ = reinterpret_cast<int32_t *>(malloc(out_shape.at(kNHWC_W) * sizeof(int32_t)));
if (resize_quant_arg_.x_axis_upper_ == nullptr) {
MS_LOG(ERROR) << "malloc x_axis_upper_ array failed.";
return RET_ERROR;
}
resize_quant_arg_.y_axis_index_ = reinterpret_cast<int32_t *>(malloc(out_shape.at(1) * sizeof(int32_t)));
resize_quant_arg_.y_axis_index_ = reinterpret_cast<int32_t *>(malloc(out_shape.at(kNHWC_H) * sizeof(int32_t)));
if (resize_quant_arg_.y_axis_index_ == nullptr) {
MS_LOG(ERROR) << "malloc y_axis_index_ array failed.";
return RET_ERROR;
}
resize_quant_arg_.y_axis_lower_ = reinterpret_cast<int32_t *>(malloc(out_shape.at(1) * sizeof(int32_t)));
resize_quant_arg_.y_axis_lower_ = reinterpret_cast<int32_t *>(malloc(out_shape.at(kNHWC_H) * sizeof(int32_t)));
if (resize_quant_arg_.y_axis_lower_ == nullptr) {
MS_LOG(ERROR) << "malloc y_axis_lower_ array failed.";
return RET_ERROR;
}
resize_quant_arg_.y_axis_upper_ = reinterpret_cast<int32_t *>(malloc(out_shape.at(1) * sizeof(int32_t)));
resize_quant_arg_.y_axis_upper_ = reinterpret_cast<int32_t *>(malloc(out_shape.at(kNHWC_H) * sizeof(int32_t)));
if (resize_quant_arg_.y_axis_upper_ == nullptr) {
MS_LOG(ERROR) << "malloc y_axis_upper_ array failed.";
return RET_ERROR;
@ -136,14 +139,14 @@ int ResizeInt8CPUKernel::CalRatio() {
auto out_tensor = out_tensors_.front();
auto out_width = out_tensor->Width();
auto out_height = out_tensor->Height();
resize_quant_arg_.ratio_x_ = ((1 << 10) * in_width + out_width / 2) / out_width;
resize_quant_arg_.ratio_y_ = ((1 << 10) * in_height + out_height / 2) / out_height;
resize_quant_arg_.ratio_x_ = ((1 << OFFSET_BASE) * in_width + out_width / 2) / out_width;
resize_quant_arg_.ratio_y_ = ((1 << OFFSET_BASE) * in_height + out_height / 2) / out_height;
bool align_corners = coordinate_transform_mode_ == schema::CoordinateTransformMode_ALIGN_CORNERS;
if (align_corners && out_width > 1) {
resize_quant_arg_.ratio_x_ = ((1 << 10) * (in_width - 1) + (out_width - 1) / 2) / (out_width - 1);
resize_quant_arg_.ratio_x_ = ((1 << OFFSET_BASE) * (in_width - 1) + (out_width - 1) / 2) / (out_width - 1);
}
if (align_corners && out_height > 1) {
resize_quant_arg_.ratio_y_ = ((1 << 10) * (in_height - 1) + (out_height - 1) / 2) / (out_height - 1);
resize_quant_arg_.ratio_y_ = ((1 << OFFSET_BASE) * (in_height - 1) + (out_height - 1) / 2) / (out_height - 1);
}
return RET_OK;
}
@ -152,46 +155,48 @@ int ResizeInt8CPUKernel::CalInterpolationRange() {
for (int i = 0; i < out_tensors_.front()->Height(); ++i) {
int32_t scaled_index = i * resize_quant_arg_.ratio_y_;
resize_quant_arg_.y_axis_index_[i] = scaled_index;
resize_quant_arg_.y_axis_lower_[i] = std::max(scaled_index / (1 << 10), 0);
resize_quant_arg_.y_axis_upper_[i] = std::min(scaled_index / (1 << 10) + 1, in_tensors_.front()->Height() - 1);
resize_quant_arg_.y_axis_lower_[i] = std::max(scaled_index / (1 << OFFSET_BASE), 0);
resize_quant_arg_.y_axis_upper_[i] =
std::min(scaled_index / (1 << OFFSET_BASE) + 1, in_tensors_.front()->Height() - 1);
}
for (int i = 0; i < out_tensors_.front()->Width(); ++i) {
int32_t scaled_index = i * resize_quant_arg_.ratio_x_;
resize_quant_arg_.x_axis_index_[i] = scaled_index;
resize_quant_arg_.x_axis_lower_[i] = std::max(scaled_index / (1 << 10), 0);
resize_quant_arg_.x_axis_upper_[i] = std::min(scaled_index / (1 << 10) + 1, in_tensors_.front()->Width() - 1);
resize_quant_arg_.x_axis_lower_[i] = std::max(scaled_index / (1 << OFFSET_BASE), 0);
resize_quant_arg_.x_axis_upper_[i] =
std::min(scaled_index / (1 << OFFSET_BASE) + 1, in_tensors_.front()->Width() - 1);
}
return RET_OK;
}
int ResizeInt8CPUKernel::InitResizeFloatQuantArg() {
auto out_shape = out_tensors_.front()->shape();
resize_float_quant_arg_.x_axis_index_ = reinterpret_cast<float *>(malloc(out_shape[2] * sizeof(float)));
resize_float_quant_arg_.x_axis_index_ = reinterpret_cast<float *>(malloc(out_shape[kNHWC_W] * sizeof(float)));
if (resize_float_quant_arg_.x_axis_index_ == nullptr) {
MS_LOG(ERROR) << "malloc x axis index array failed.";
return RET_ERROR;
}
resize_float_quant_arg_.x_axis_lower_ = reinterpret_cast<int32_t *>(malloc(out_shape[2] * sizeof(int32_t)));
resize_float_quant_arg_.x_axis_lower_ = reinterpret_cast<int32_t *>(malloc(out_shape[kNHWC_W] * sizeof(int32_t)));
if (resize_float_quant_arg_.x_axis_lower_ == nullptr) {
MS_LOG(ERROR) << "malloc x_axis_lower_ array failed.";
return RET_ERROR;
}
resize_float_quant_arg_.x_axis_upper_ = reinterpret_cast<int32_t *>(malloc(out_shape[2] * sizeof(int32_t)));
resize_float_quant_arg_.x_axis_upper_ = reinterpret_cast<int32_t *>(malloc(out_shape[kNHWC_W] * sizeof(int32_t)));
if (resize_float_quant_arg_.x_axis_upper_ == nullptr) {
MS_LOG(ERROR) << "malloc x_axis_upper_ array failed.";
return RET_ERROR;
}
resize_float_quant_arg_.y_axis_index_ = reinterpret_cast<float *>(malloc(out_shape[1] * sizeof(float)));
resize_float_quant_arg_.y_axis_index_ = reinterpret_cast<float *>(malloc(out_shape[kNHWC_H] * sizeof(float)));
if (resize_float_quant_arg_.y_axis_index_ == nullptr) {
MS_LOG(ERROR) << "malloc y_axis_index_ array failed.";
return RET_ERROR;
}
resize_float_quant_arg_.y_axis_lower_ = reinterpret_cast<int32_t *>(malloc(out_shape[1] * sizeof(int32_t)));
resize_float_quant_arg_.y_axis_lower_ = reinterpret_cast<int32_t *>(malloc(out_shape[kNHWC_H] * sizeof(int32_t)));
if (resize_float_quant_arg_.y_axis_lower_ == nullptr) {
MS_LOG(ERROR) << "malloc y_axis_lower_ array failed.";
return RET_ERROR;
}
resize_float_quant_arg_.y_axis_upper_ = reinterpret_cast<int32_t *>(malloc(out_shape[1] * sizeof(int32_t)));
resize_float_quant_arg_.y_axis_upper_ = reinterpret_cast<int32_t *>(malloc(out_shape[kNHWC_H] * sizeof(int32_t)));
if (resize_float_quant_arg_.y_axis_upper_ == nullptr) {
MS_LOG(ERROR) << "malloc y_axis_upper_ array failed.";
return RET_ERROR;

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -29,6 +29,7 @@ namespace mindspore::kernel {
namespace {
// NOTE(review): presumably the in_tensors_ count for a scale op without an
// offset (input + scale) -- its use is not visible in this hunk; confirm.
constexpr size_t kScaleInputsSize = 2;
// in_tensors_ count when an offset tensor is supplied; InitScaleOffset() sets
// has_bias_ when in_tensors_.size() equals this value.
constexpr size_t kScaleBiasInputsSize = 3;
// Position of the offset tensor inside in_tensors_.
constexpr int kOffsetIndex = 2;
} // namespace
ScaleInt8CPUKernel::~ScaleInt8CPUKernel() {
if (tile_para != nullptr) {
@ -68,16 +69,16 @@ int ScaleInt8CPUKernel::InitScaleOffset() {
}
scale_param_->const_offset_ = false;
if (in_tensors_.size() == 3) {
if (in_tensors_.size() == kScaleBiasInputsSize) {
has_bias_ = true;
auto offset_tensor = in_tensors_.at(2);
auto offset_tensor = in_tensors_.at(kOffsetIndex);
auto *offset_ptr = reinterpret_cast<int8_t *>(offset_tensor->data_c());
// offset may be a const value, which can be processed in the prepare stage
if (offset_ptr != nullptr) {
scale_param_->const_offset_ = true;
input2_data_ = offset_ptr;
// need broadcasting
if (in_tensors_.at(0)->ElementsNum() != in_tensors_.at(2)->ElementsNum()) {
if (in_tensors_.at(0)->ElementsNum() != in_tensors_.at(kOffsetIndex)->ElementsNum()) {
input2_data_ = reinterpret_cast<int8_t *>(malloc(out_tensors_.at(0)->Size()));
if (input2_data_ == nullptr) {
MS_LOG(ERROR) << "malloc input2_data_ failed.";
@ -88,7 +89,7 @@ int ScaleInt8CPUKernel::InitScaleOffset() {
return RET_ERROR;
}
malloced_offset_ = true;
TileOneDimensionInt8(reinterpret_cast<int8_t *>(in_tensors_.at(2)->data_c()),
TileOneDimensionInt8(reinterpret_cast<int8_t *>(in_tensors_.at(kOffsetIndex)->data_c()),
reinterpret_cast<int8_t *>(input2_data_), 0, tile_para->ndim_, tile_para->in_shape1_,
tile_para->in_strides1_, tile_para->out_strides_, tile_para->multiples1_);
}
@ -187,7 +188,7 @@ int ScaleInt8CPUKernel::InitQuantArgs() {
scale_param_->scale_mul_arg_.right_shift_ = shift < 0 ? -shift : 0;
if (in_tensors_.size() == kScaleBiasInputsSize) {
auto offset = in_tensors_.at(2);
auto offset = in_tensors_.at(kOffsetIndex);
auto offset_scale = offset->quant_params().front().scale;
scale_param_->offset_zp_ = offset->quant_params().front().zeroPoint;
@ -290,7 +291,7 @@ int ScaleInt8CPUKernel::Run() {
count_unit_ = thread_count_ > 1 ? UP_DIV(elements_num_, thread_count_) : elements_num_;
input0_data_ = reinterpret_cast<int8_t *>(in_tensors_.at(0)->data_c());
output_data_ = reinterpret_cast<int8_t *>(out_tensors_.at(0)->data_c());
int ret = RET_ERROR;
// need broadcasting
if (in_tensors_.at(0)->ElementsNum() != in_tensors_.at(1)->ElementsNum()) {
// scale is passed by previous node, need do broadcasting online
@ -314,12 +315,12 @@ int ScaleInt8CPUKernel::Run() {
input1_data_ = nullptr;
return RET_ERROR;
}
TileOneDimensionInt8(reinterpret_cast<int8_t *>(in_tensors_.at(2)->data_c()),
TileOneDimensionInt8(reinterpret_cast<int8_t *>(in_tensors_.at(kOffsetIndex)->data_c()),
reinterpret_cast<int8_t *>(input2_data_), 0, tile_para->ndim_, tile_para->in_shape1_,
tile_para->in_strides1_, tile_para->out_strides_, tile_para->multiples1_);
}
auto ret = static_cast<const lite::InnerContext *>(this->context_)
ret = static_cast<const lite::InnerContext *>(this->context_)
->thread_pool_->ParallelLaunch(ScaleRunInt8, this, op_parameter_->thread_num_);
// free memory malloced from memory pool
if (!scale_param_->const_scale_) {
@ -338,9 +339,9 @@ int ScaleInt8CPUKernel::Run() {
input1_data_ = reinterpret_cast<int8_t *>(in_tensors_.at(1)->data_c());
}
if (has_bias_ && !scale_param_->const_offset_) {
input2_data_ = reinterpret_cast<int8_t *>(in_tensors_.at(2)->data_c());
input2_data_ = reinterpret_cast<int8_t *>(in_tensors_.at(kOffsetIndex)->data_c());
}
auto ret = static_cast<const lite::InnerContext *>(this->context_)
ret = static_cast<const lite::InnerContext *>(this->context_)
->thread_pool_->ParallelLaunch(ScaleRunInt8, this, op_parameter_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Scale error error_code[" << ret << "]";

View File

@ -27,7 +27,6 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_SliceFusion;
namespace mindspore::kernel {
int SliceInt8CPUKernel::Init() {
auto input = in_tensors_.at(0);
auto output = out_tensors_.at(0);

View File

@ -30,7 +30,6 @@ using mindspore::lite::KernelRegistrar;
using mindspore::schema::PrimitiveType_Split;
namespace mindspore::kernel {
int SplitInt8CPUKernel::Init() {
auto ret = SplitBaseCPUKernel::Init();
if (ret != RET_OK) {

View File

@ -80,15 +80,15 @@ int TransposeInt8CPUKernel::DoTranspose(int task_id) {
void TransposeInt8CPUKernel::GetNHNCTransposeFunc(lite::Tensor *in_tensor, lite::Tensor *out_tensor,
TransposeParameter *param) {
auto out_shape = out_tensor->shape();
if (in_tensor->shape().size() == 4 && param->perm_[0] == 0 && param->perm_[1] == 2 && param->perm_[2] == 3 &&
param->perm_[3] == 1) {
if (in_tensor->shape().size() == DIMENSION_4D && param->perm_[0] == 0 && param->perm_[1] == 2 &&
param->perm_[2] == 3 && param->perm_[3] == 1) {
nhnc_param_[0] = out_shape[0];
nhnc_param_[1] = out_shape[1] * out_shape[2];
nhnc_param_[2] = out_shape[3];
NHNCTransposeFunc_ = PackNCHWToNHWCInt8;
}
if (in_tensor->shape().size() == 4 && param->perm_[0] == 0 && param->perm_[1] == 3 && param->perm_[2] == 1 &&
param->perm_[3] == 2) {
if (in_tensor->shape().size() == DIMENSION_4D && param->perm_[0] == 0 && param->perm_[1] == 3 &&
param->perm_[2] == 1 && param->perm_[3] == 2) {
nhnc_param_[0] = out_shape[0];
nhnc_param_[1] = out_shape[2] * out_shape[3];
nhnc_param_[2] = out_shape[1];

View File

@ -37,7 +37,6 @@ const std::map<std::string, std::string> *kRegexTransforms = new (std::nothrow)
{"i'm", "i am"},
});
const int32_t kMaxStringLength = 300;
} // namespace
int NormalizeCPUKernel::Init() {
@ -49,7 +48,7 @@ int NormalizeCPUKernel::Init() {
// No-op: performs no work and always reports RET_OK.
int NormalizeCPUKernel::ReSize() {
  return RET_OK;
}
std::string NormalizeCPUKernel::Trim(const std::string &str, const std::string &pattern /*= " \t\n\v\f\r"*/) {
std::string NormalizeCPUKernel::Trim(const std::string &str, const std::string &pattern /* = " \t\n\v\f\r" */) {
auto begin = str.find_first_not_of(pattern);
if (begin == std::string::npos) {
MS_LOG(WARNING) << "Meaningless input string!";

View File

@ -24,6 +24,12 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_CustomPredict;
namespace mindspore::kernel {
namespace {
// Positions of the tensors inside in_tensors_ as read by
// PredictCPUKernel::GetLabelInfo().
constexpr int INPUT_INDEX = 0;   // input tensor
constexpr int KEY_INDEX = 1;     // keys tensor
constexpr int LABEL_INDEX = 2;   // labels tensor
constexpr int WEIGHT_INDEX = 3;  // weights tensor
} // namespace
int PredictCPUKernel::Init() {
if (!InferShapeDone()) {
return RET_OK;
@ -35,10 +41,10 @@ int PredictCPUKernel::ReSize() { return RET_OK; }
std::vector<LabelInfo> PredictCPUKernel::GetLabelInfo() {
std::vector<LabelInfo> label_info_vec;
auto input_tensor = in_tensors_.at(0);
auto keys_tensor = in_tensors_.at(1);
auto labels_tensor = in_tensors_.at(2);
auto weights_tensor = in_tensors_.at(3);
auto input_tensor = in_tensors_.at(INPUT_INDEX);
auto keys_tensor = in_tensors_.at(KEY_INDEX);
auto labels_tensor = in_tensors_.at(LABEL_INDEX);
auto weights_tensor = in_tensors_.at(WEIGHT_INDEX);
int32_t *input = reinterpret_cast<int32_t *>(input_tensor->MutableData());
int32_t *key_begin = reinterpret_cast<int32_t *>(keys_tensor->MutableData());

View File

@ -22,5 +22,4 @@
namespace mindspore::kernel {
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
} // namespace mindspore::kernel

View File

@ -43,6 +43,26 @@ int ActivationNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs,
return RET_ERROR;
}
act_->set_input_x(*npu_inputs[0]);
/*
* mode : Activation mode, with options as follows:
* 0 : Sigmoid
* 1 : ReLU
* 2 : Tanh
* 3 : Clipped ReLU
* 4 : ELU
* 5 : PReLU
* 6 : Abs
* 7 : Relu1
* 8 : Softsign
* 9 : Softplus
* 10 : Hardsigmoid
* 11 : Threshold ReLU
* 12 : Selu
* 13 : Linear
* 14 : Relu6
* 15 : GeLU.
* Defaults to 1 (ReLU).
*/
switch (act_param_->type_) {
case schema::ActivationType_SIGMOID:
act_->set_attr_mode(0);

View File

@ -44,6 +44,10 @@ using mindspore::schema::PrimitiveType_NotEqual;
using mindspore::schema::PrimitiveType_SubFusion;
namespace mindspore::kernel {
namespace {
// Mode values handed to act_->set_attr_mode() in SetActivation().
constexpr int RELU_MODE = 1;    // used for ActivationType_RELU
constexpr int RELU6_MODE = 14;  // used for ActivationType_RELU6
} // namespace
int ArithmeticNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter) {
if (inputs[0]->shape() != inputs[1]->shape()) {
@ -84,9 +88,9 @@ int ArithmeticNPUKernel::SetActivation() {
}
act_->set_input_x(*op_);
if (activation_type_ == ActivationType_RELU) {
act_->set_attr_mode(1);
act_->set_attr_mode(RELU_MODE);
} else if (activation_type_ == ActivationType_RELU6) {
act_->set_attr_mode(14);
act_->set_attr_mode(RELU6_MODE);
} else {
MS_LOG(ERROR) << "Unsupported activation type for op " << name_;
return RET_ERROR;

View File

@ -19,6 +19,16 @@
#include "nnacl/pack.h"
namespace mindspore::kernel {
namespace {
// Positions inside the weight shape (w_shape) used by InitWeightConst() when
// repacking the weight with PackNHWCToNCHWFp32 and building the NCHW tensor
// descriptor.
constexpr int BATCH_INDEX = 0;
constexpr int HEIGHT_INDEX = 1;
constexpr int WIDTH_INDEX = 2;
constexpr int CHANNEL_INDEX = 3;
// InitBiasConst() creates a bias constant only when inputs.size() is at least
// WITH_BIAS_SIZE; the bias tensor is then read from inputs[BIAS_INDEX].
constexpr size_t WITH_BIAS_SIZE = 3;
constexpr int BIAS_INDEX = 2;
// Mode values handed to act_->set_attr_mode() in SetActivation().
constexpr int RELU_MODE = 1;    // used for ActType_Relu
constexpr int RELU6_MODE = 14;  // used for ActType_Relu6
} // namespace
ConvolutionBaseNPUKernel::~ConvolutionBaseNPUKernel() {
if (act_ != nullptr) {
delete act_;
@ -47,14 +57,16 @@ int ConvolutionBaseNPUKernel::InitWeightConst(const std::vector<lite::Tensor *>
MS_LOG(ERROR) << "Malloc buffer failed.";
return RET_ERROR;
}
PackNHWCToNCHWFp32(nhwc_data, nchw_data, w_shape[0], w_shape[1] * w_shape[2], w_shape[3], 0, 0);
PackNHWCToNCHWFp32(nhwc_data, nchw_data, w_shape[BATCH_INDEX], w_shape[HEIGHT_INDEX] * w_shape[WIDTH_INDEX],
w_shape[CHANNEL_INDEX], 0, 0);
std::shared_ptr<ge::Tensor> weight_tensor = std::shared_ptr<ge::Tensor>(new (std::nothrow) ge::Tensor());
std::shared_ptr<ge::Tensor> weight_tensor = std::make_shared<ge::Tensor>();
if (weight_tensor == nullptr) {
MS_LOG(ERROR) << "new weight_tensor failed.";
return RET_ERROR;
}
ge::TensorDesc tensor_desc(lite::ConverterToNPUShape({w_shape[0], w_shape[3], w_shape[1], w_shape[2]}),
ge::TensorDesc tensor_desc(lite::ConverterToNPUShape({w_shape[BATCH_INDEX], w_shape[CHANNEL_INDEX],
w_shape[HEIGHT_INDEX], w_shape[WIDTH_INDEX]}),
ge::FORMAT_NCHW, lite::ConverterToNPUDataType(inputs[1]->data_type()));
weight_tensor->SetTensorDesc(tensor_desc);
weight_tensor->SetData(reinterpret_cast<const uint8_t *>(nchw_data), inputs[1]->Size());
@ -65,16 +77,16 @@ int ConvolutionBaseNPUKernel::InitWeightConst(const std::vector<lite::Tensor *>
}
int ConvolutionBaseNPUKernel::InitBiasConst(const std::vector<lite::Tensor *> &inputs) {
if (inputs.size() >= 3) {
if (inputs.size() >= WITH_BIAS_SIZE) {
bias_ = new (std::nothrow) hiai::op::Const(name_ + "_b");
if (bias_ == nullptr) {
MS_LOG(ERROR) << "New bias const failed.";
return RET_ERROR;
}
inputs[2]->set_format(schema::Format_NCHW);
auto bias_tensor = mindspore::lite::ConverterToNPUTensor(inputs[2]);
inputs[BIAS_INDEX]->set_format(schema::Format_NCHW);
auto bias_tensor = mindspore::lite::ConverterToNPUTensor(inputs[BIAS_INDEX]);
bias_->set_attr_value(bias_tensor);
inputs[2]->set_format(schema::Format_NHWC);
inputs[BIAS_INDEX]->set_format(schema::Format_NHWC);
}
return RET_OK;
}
@ -87,9 +99,9 @@ int ConvolutionBaseNPUKernel::SetActivation(const ge::Operator *input, ActType a
}
act_->set_input_x(*input);
if (act_type == ActType_Relu) {
act_->set_attr_mode(1);
act_->set_attr_mode(RELU_MODE);
} else if (act_type == ActType_Relu6) {
act_->set_attr_mode(14);
act_->set_attr_mode(RELU6_MODE);
} else {
MS_LOG(ERROR) << "Unsupported activation type for convolution.";
return RET_ERROR;

View File

@ -22,14 +22,18 @@ using mindspore::lite::KernelRegistrar;
using mindspore::schema::PrimitiveType_Gather;
namespace mindspore::kernel {
namespace {
// IsSupport() reads the gather axis from inputs[AXIS_INDEX] only when
// inputs.size() is at least WITH_AXES_SIZE and that tensor holds exactly one
// element; otherwise the kernel is reported as unsupported.
constexpr size_t WITH_AXES_SIZE = 3;
constexpr int AXIS_INDEX = 2;
} // namespace
int GatherNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
OpParameter *opParameter) {
if (inputs[1]->data_type() != kNumberTypeInt32) {
MS_LOG(WARNING) << "Gather indices only support Int32";
return RET_ERROR;
}
if (inputs.size() >= 3 && inputs[2]->ElementsNum() == 1) {
axis_ = static_cast<int *>(inputs[2]->data_c())[0];
if (inputs.size() >= WITH_AXES_SIZE && inputs[AXIS_INDEX]->ElementsNum() == 1) {
axis_ = static_cast<int *>(inputs[AXIS_INDEX]->data_c())[0];
} else {
MS_LOG(WARNING) << "NPU axis is attribute.";
return RET_ERROR;

View File

@ -24,6 +24,10 @@ using mindspore::lite::KernelRegistrar;
using mindspore::schema::PrimitiveType_InstanceNorm;
namespace mindspore::kernel {
namespace {
// Positions of the gamma and beta tensors inside the `inputs` vector passed to
// InstanceNormNPUKernel::SetNPUInputs().
constexpr int GAMMA_INDEX = 1;
constexpr int BETA_INDEX = 2;
} // namespace
int InstanceNormNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter) {
return RET_OK;
@ -39,16 +43,16 @@ int InstanceNormNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &input
}
op_->set_input_x(*npu_inputs[0]);
auto gamma_shape = inputs[1]->shape();
std::shared_ptr<ge::Tensor> gamma_tensor = std::shared_ptr<ge::Tensor>(new (std::nothrow) ge::Tensor());
auto gamma_shape = inputs[GAMMA_INDEX]->shape();
std::shared_ptr<ge::Tensor> gamma_tensor = std::make_shared<ge::Tensor>();
if (gamma_tensor == nullptr) {
MS_LOG(ERROR) << "new gamma_tensor failed.";
return RET_ERROR;
}
ge::TensorDesc gamma_tensor_desc(lite::ConverterToNPUShape({1, gamma_shape[0], 1, 1}), ge::FORMAT_NCHW,
lite::ConverterToNPUDataType(inputs[1]->data_type()));
lite::ConverterToNPUDataType(inputs[GAMMA_INDEX]->data_type()));
gamma_tensor->SetTensorDesc(gamma_tensor_desc);
gamma_tensor->SetData(reinterpret_cast<const uint8_t *>(inputs[1]->data_c()), inputs[1]->Size());
gamma_tensor->SetData(reinterpret_cast<const uint8_t *>(inputs[GAMMA_INDEX]->data_c()), inputs[GAMMA_INDEX]->Size());
gamma_ = new (std::nothrow) hiai::op::Const(name_ + "_gamma");
if (gamma_ == nullptr) {
MS_LOG(ERROR) << "New gamma_ const failed.";
@ -57,16 +61,16 @@ int InstanceNormNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &input
gamma_->set_attr_value(gamma_tensor);
op_->set_input_gamma(*gamma_);
auto beta_shape = inputs[2]->shape();
std::shared_ptr<ge::Tensor> beta_tensor = std::shared_ptr<ge::Tensor>(new (std::nothrow) ge::Tensor());
auto beta_shape = inputs[BETA_INDEX]->shape();
std::shared_ptr<ge::Tensor> beta_tensor = std::make_shared<ge::Tensor>();
if (beta_tensor == nullptr) {
MS_LOG(ERROR) << "new beta_tensor failed.";
return RET_ERROR;
}
ge::TensorDesc beta_tensor_desc(lite::ConverterToNPUShape({1, beta_shape[0], 1, 1}), ge::FORMAT_NCHW,
lite::ConverterToNPUDataType(inputs[2]->data_type()));
lite::ConverterToNPUDataType(inputs[BETA_INDEX]->data_type()));
beta_tensor->SetTensorDesc(beta_tensor_desc);
beta_tensor->SetData(reinterpret_cast<const uint8_t *>(inputs[2]->data_c()), inputs[2]->Size());
beta_tensor->SetData(reinterpret_cast<const uint8_t *>(inputs[BETA_INDEX]->data_c()), inputs[BETA_INDEX]->Size());
beta_ = new (std::nothrow) hiai::op::Const(name_ + "_beta");
if (beta_ == nullptr) {
MS_LOG(ERROR) << "New beta_ const failed.";

View File

@ -47,10 +47,11 @@ int PadNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const
MS_LOG(ERROR) << name_ << " op is nullptr";
return RET_ERROR;
}
// padding shape is spatial_dim x 2.
int size = static_cast<int>(param_->padding_length / 2);
ge::TensorDesc padding_tensor_desc(ge::Shape({size, 2}), ge::FORMAT_NCHW, ge::DT_INT32);
ge::TensorPtr padding_tensor = std::make_shared<hiai::Tensor>(padding_tensor_desc);
padding_tensor->SetData(reinterpret_cast<uint8_t *>(param_->paddings_), 2 * size * sizeof(int));
padding_tensor->SetData(reinterpret_cast<uint8_t *>(param_->paddings_), param_->padding_length * sizeof(int));
hiai_paddings_ = new hiai::op::Const(name_ + "paddings");
hiai_paddings_->set_attr_value(padding_tensor);

View File

@ -23,6 +23,13 @@ using mindspore::schema::PrimitiveType_AvgPoolFusion;
using mindspore::schema::PrimitiveType_MaxPoolFusion;
namespace mindspore::kernel {
namespace {
// Values handed to the pooling operator's set_attr_mode():
// MAX_MODE for PoolMode_MaxPool, AVG_MODE for PoolMode_AvgPool,
// L2_MODE for every other pool mode.
constexpr int MAX_MODE{0};
constexpr int AVG_MODE{1};
constexpr int L2_MODE{2};
// Values handed to set_attr_pad_mode() for Pad_same and Pad_valid respectively.
constexpr int PAD_MODE_SAME{6};
constexpr int PAD_MODE_VALID{5};
}  // namespace
int PoolingNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
OpParameter *opParameter) {
if (pooling_param_->pad_l_ > pooling_param_->stride_w_ || pooling_param_->pad_u_ > pooling_param_->stride_h_) {
@ -34,20 +41,20 @@ int PoolingNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const
int PoolingNPUKernel::SetPoolingParam() {
if (pooling_param_->pool_mode_ == PoolMode_MaxPool) {
pooling_->set_attr_mode(0);
pooling_->set_attr_mode(MAX_MODE);
} else if (pooling_param_->pool_mode_ == PoolMode_AvgPool) {
pooling_->set_attr_mode(1);
pooling_->set_attr_mode(AVG_MODE);
} else {
pooling_->set_attr_mode(2);
pooling_->set_attr_mode(L2_MODE);
}
pooling_->set_attr_global_pooling(pooling_param_->global_);
pooling_->set_attr_window({pooling_param_->window_h_, pooling_param_->window_w_});
pooling_->set_attr_stride({pooling_param_->stride_h_, pooling_param_->stride_w_});
if (pooling_param_->pad_mode_ == Pad_same) {
pooling_->set_attr_pad_mode(6);
pooling_->set_attr_pad_mode(PAD_MODE_SAME);
pooling_->set_attr_pad({0, 0, 0, 0});
} else if (pooling_param_->pad_mode_ == Pad_valid) {
pooling_->set_attr_pad_mode(5);
pooling_->set_attr_pad_mode(PAD_MODE_VALID);
pooling_->set_attr_pad({0, 0, 0, 0});
} else {
pooling_->set_attr_pad_mode(0);

View File

@ -25,6 +25,9 @@ using mindspore::lite::KernelRegistrar;
using mindspore::schema::PrimitiveType_Resize;
namespace mindspore::kernel {
namespace {
// Element count of the size tensor data (new height, new width).
constexpr size_t SIZE_TENSOR_DIMS{2};
}  // namespace
int ResizeNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
OpParameter *opParameter) {
if (resize_parameter_->method_ != schema::ResizeMethod_LINEAR &&
@ -45,32 +48,32 @@ int ResizeNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, con
ge::TensorPtr sizeTensor = std::make_shared<hiai::Tensor>(sizeTensorDesc);
vector<int32_t> dataValue = {static_cast<int32_t>(resize_parameter_->new_height_),
static_cast<int32_t>(resize_parameter_->new_width_)};
sizeTensor->SetData(reinterpret_cast<uint8_t *>(dataValue.data()), 2 * sizeof(int32_t));
sizeTensor->SetData(reinterpret_cast<uint8_t *>(dataValue.data()), SIZE_TENSOR_DIMS * sizeof(int32_t));
out_size_ = new (std::nothrow) hiai::op::Const(name_ + "_size");
out_size_->set_attr_value(sizeTensor);
if (resize_parameter_->method_ == schema::ResizeMethod_LINEAR) {
auto op = new (std::nothrow) hiai::op::ResizeBilinearV2(name_);
if (op == nullptr) {
MS_LOG(ERROR) << " op is nullptr.";
auto linear_op = new (std::nothrow) hiai::op::ResizeBilinearV2(name_);
if (linear_op == nullptr) {
MS_LOG(ERROR) << " linear_op is nullptr.";
return RET_ERROR;
}
op->set_attr_align_corners(resize_parameter_->coordinate_transform_mode_ ==
linear_op->set_attr_align_corners(resize_parameter_->coordinate_transform_mode_ ==
schema::CoordinateTransformMode_ALIGN_CORNERS);
op->set_input_x(*npu_inputs[0]);
op->set_input_size(*out_size_);
op->set_attr_half_pixel_centers(resize_parameter_->preserve_aspect_ratio_);
op_ = op;
linear_op->set_input_x(*npu_inputs[0]);
linear_op->set_input_size(*out_size_);
linear_op->set_attr_half_pixel_centers(resize_parameter_->preserve_aspect_ratio_);
op_ = linear_op;
} else if (resize_parameter_->method_ == schema::ResizeMethod_NEAREST) {
auto op = new (std::nothrow) hiai::op::ResizeNearestNeighborV2(name_);
if (op == nullptr) {
MS_LOG(ERROR) << " op is nullptr.";
auto nearest_op = new (std::nothrow) hiai::op::ResizeNearestNeighborV2(name_);
if (nearest_op == nullptr) {
MS_LOG(ERROR) << " nearest_op is nullptr.";
return RET_ERROR;
}
op->set_attr_align_corners(resize_parameter_->coordinate_transform_mode_ ==
nearest_op->set_attr_align_corners(resize_parameter_->coordinate_transform_mode_ ==
schema::CoordinateTransformMode_ALIGN_CORNERS);
op->set_input_x(*npu_inputs[0]);
op->set_input_size(*out_size_);
op_ = op;
nearest_op->set_input_x(*npu_inputs[0]);
nearest_op->set_input_size(*out_size_);
op_ = nearest_op;
} else {
MS_LOG(WARNING) << "Unsupported resize method type:" << resize_parameter_->method_;
return RET_ERROR;

View File

@ -25,12 +25,20 @@ using mindspore::schema::Format_NHWC;
using mindspore::schema::PrimitiveType_ScaleFusion;
namespace mindspore::kernel {
namespace {
// Rank the first input's shape is compared against in the NHWC support check.
constexpr int DIMS_4D{4};
// Position of the optional bias tensor inside the inputs vector.
constexpr int BIAS_INDEX{2};
// Input count when no bias is supplied; a bias is read only when inputs.size() exceeds it.
constexpr size_t NONE_BIAS_SIZE{2};
// Values handed to the activation operator's set_attr_mode() for RELU and RELU6.
constexpr int RELU_MODE{1};
constexpr int RELU6_MODE{14};
}  // namespace
int ScaleNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
OpParameter *opParameter) {
if (scale_parameter_->axis_ < 0) {
scale_parameter_->axis_ = scale_parameter_->axis_ + inputs[0]->shape().size();
}
if (inputs.size() > 1 && inputs[0]->shape().size() == 4 && inputs[0]->format() == schema::Format_NHWC) {
if (inputs.size() > 1 && inputs[0]->shape().size() == DIMS_4D && inputs[0]->format() == schema::Format_NHWC) {
// scale currently only supports axis 3
if (scale_parameter_->axis_ != 3) {
MS_LOG(ERROR) << "Npu scale axis attr only support on channel, now is " << scale_parameter_->axis_;
return RET_ERROR;
@ -56,7 +64,7 @@ int ScaleNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, cons
MS_ASSERT(inputs.size() > 1);
auto scale_shape = inputs.at(1)->shape();
std::shared_ptr<ge::Tensor> scale_tensor = std::shared_ptr<ge::Tensor>(new (std::nothrow) ge::Tensor());
std::shared_ptr<ge::Tensor> scale_tensor = std::make_shared<ge::Tensor>();
if (scale_tensor == nullptr) {
MS_LOG(ERROR) << "new scale_tensor failed.";
return RET_ERROR;
@ -73,17 +81,19 @@ int ScaleNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, cons
scale_->set_attr_value(scale_tensor);
op_->set_input_scale(*scale_);
if (inputs.size() > 2 && inputs[2] != nullptr) {
auto bias_shape = inputs[2]->shape();
std::shared_ptr<ge::Tensor> bias_tensor = std::shared_ptr<ge::Tensor>(new (std::nothrow) ge::Tensor());
// inputs size can be larger than 2 when optional bias is provided.
// bias index 2
if (inputs.size() > NONE_BIAS_SIZE && inputs[BIAS_INDEX] != nullptr) {
auto bias_shape = inputs[BIAS_INDEX]->shape();
std::shared_ptr<ge::Tensor> bias_tensor = std::make_shared<ge::Tensor>();
if (bias_tensor == nullptr) {
MS_LOG(ERROR) << "new bias_tensor failed.";
return RET_ERROR;
}
ge::TensorDesc bias_tensor_desc(lite::ConverterToNPUShape({1, bias_shape[0], 1, 1}), ge::FORMAT_NCHW,
lite::ConverterToNPUDataType(inputs[2]->data_type()));
lite::ConverterToNPUDataType(inputs[BIAS_INDEX]->data_type()));
bias_tensor->SetTensorDesc(bias_tensor_desc);
bias_tensor->SetData(reinterpret_cast<const uint8_t *>(inputs[2]->data_c()), inputs[2]->Size());
bias_tensor->SetData(reinterpret_cast<const uint8_t *>(inputs[BIAS_INDEX]->data_c()), inputs[BIAS_INDEX]->Size());
bias_ = new (std::nothrow) hiai::op::Const(name_ + "_beta");
if (bias_ == nullptr) {
MS_LOG(ERROR) << "New beta_ const failed.";
@ -120,9 +130,9 @@ int ScaleNPUKernel::SetActivation(const ge::Operator *input, int act_type) {
}
act_->set_input_x(*input);
if (act_type == schema::ActivationType_RELU) {
act_->set_attr_mode(1);
act_->set_attr_mode(RELU_MODE);
} else if (act_type == schema::ActivationType_RELU6) {
act_->set_attr_mode(14);
act_->set_attr_mode(RELU6_MODE);
} else {
MS_LOG(ERROR) << "Unsupported activation type for scale.";
return RET_ERROR;

View File

@ -22,6 +22,11 @@ using mindspore::lite::KernelRegistrar;
using mindspore::schema::PrimitiveType_SliceFusion;
namespace mindspore::kernel {
namespace {
// Positions inside npu_inputs of the data, offsets and size operands.
constexpr int INPUT_INDEX{0};
constexpr int OFFSET_INDEX{1};
constexpr int SIZE_INDEX{2};
}  // namespace
int SliceNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
OpParameter *opParameter) {
return RET_OK;
@ -34,9 +39,9 @@ int SliceNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, cons
MS_LOG(ERROR) << name_ << " op is nullptr";
return RET_ERROR;
}
op_->set_input_x(*npu_inputs[0]);
op_->set_input_offsets(*npu_inputs[1]);
op_->set_input_size(*npu_inputs[2]);
op_->set_input_x(*npu_inputs[INPUT_INDEX]);
op_->set_input_offsets(*npu_inputs[OFFSET_INDEX]);
op_->set_input_size(*npu_inputs[SIZE_INDEX]);
return RET_OK;
}

View File

@ -22,14 +22,20 @@ using mindspore::lite::KernelRegistrar;
using mindspore::schema::PrimitiveType_StridedSlice;
namespace mindspore::kernel {
namespace {
// Positions inside npu_inputs of the data, begin and end operands.
constexpr int INPUT_INDEX{0};
constexpr int BEGIN_INDEX{1};
constexpr int END_INDEX{2};
// Position inside inputs of the axes tensor, read when there are 5 inputs.
constexpr int AXES_INDEX{3};
}  // namespace
int StridedSliceNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter) {
// Only onnx StridedSlice has 5 inputs, of which the 4th input is axes and the 5th input is strides.
if (inputs.size() == 5) {
vector<int> axes;
size_t size = inputs[3]->shape()[0];
size_t size = inputs[AXES_INDEX]->shape()[0];
axes.resize(size);
memcpy(axes.data(), inputs[3]->data_c(), sizeof(int) * size);
memcpy(axes.data(), inputs[AXES_INDEX]->data_c(), sizeof(int) * size);
for (int i = 0; i < axes.size(); ++i) {
if (i != axes[i]) {
MS_LOG(ERROR) << "Does not support setting axis, so the axis must be continuous.";
@ -49,14 +55,16 @@ int StridedSliceNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &input
MS_LOG(ERROR) << name_ << " op is nullptr";
return RET_ERROR;
}
op_->set_input_x(*npu_inputs[0]);
op_->set_input_begin(*npu_inputs[1]);
op_->set_input_end(*npu_inputs[2]);
op_->set_input_x(*npu_inputs[INPUT_INDEX]);
op_->set_input_begin(*npu_inputs[BEGIN_INDEX]);
op_->set_input_end(*npu_inputs[END_INDEX]);
// The strides position of onnx is the 5th, and the others are the 4th.
// For onnx models, input size 5.
if (npu_inputs.size() == 5) {
// The strides position of onnx models is the 5th, index 4.
op_->set_input_strides(*npu_inputs[4]);
} else {
// The strides position of other models is the 4th, index 3.
op_->set_input_strides(*npu_inputs[3]);
}
op_->set_attr_begin_mask(param_->begins_mask_);

View File

@ -22,13 +22,17 @@ using mindspore::lite::KernelRegistrar;
using mindspore::schema::PrimitiveType_Transpose;
namespace mindspore::kernel {
namespace {
// Input count at which the permutation is read from the second input tensor.
constexpr size_t TRANSPOSE_INPUT_SIZE{2};
}  // namespace
int TransposeNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
OpParameter *opParameter) {
if (conjugate_) {
MS_LOG(ERROR) << "Unsupported conjugate transpose.";
return RET_ERROR;
}
if (inputs.size() >= 2 && inputs[1]->data_c() != nullptr) {
if (inputs.size() == TRANSPOSE_INPUT_SIZE && inputs[1]->data_c() != nullptr) {
for (int i = 0; i < inputs[1]->ElementsNum(); i++) {
perm_.push_back(static_cast<int *>(inputs[1]->data_c())[i]);
}