forked from mindspore-Ecosystem/mindspore
!9882 [MSLITE][DEVELOP] add npu op conv, conv_depthwise, pooling, activation
From: @yangruoqi713 Reviewed-by: @zhang_xue_tong,@hangangqiang Signed-off-by: @zhang_xue_tong
This commit is contained in:
commit
aeeed2b8f6
|
@ -49,6 +49,7 @@ typedef struct ConvParameter {
|
|||
int thread_num_;
|
||||
int input_unit_;
|
||||
int output_unit_;
|
||||
PadMode pad_mode_;
|
||||
ActType act_type_;
|
||||
} ConvParameter;
|
||||
|
||||
|
|
|
@ -79,6 +79,7 @@ typedef struct OpParameter {
|
|||
} OpParameter;
|
||||
|
||||
// Activation kinds carried by the parameter structs (ConvParameter::act_type_ and
// PoolingParameter::act_type_ both use this type).
// NOTE(review): "ActType_Sigmod" looks like a misspelling of "Sigmoid"; renaming it would affect
// every user of the enumerator, so confirm before changing.
typedef enum ActType { ActType_No, ActType_Relu, ActType_Sigmod, ActType_Relu6, ActType_Prelu } ActType;
|
||||
// Padding modes stored in ConvParameter::pad_mode_ and PoolingParameter::pad_mode_.
// The convolution populate code maps schema::PadMode_SAME_UPPER to Pad_Same,
// schema::PadMode_VALID to Pad_Valid and every other schema value to Pad_No.
typedef enum PadMode { Pad_No, Pad_Same, Pad_Valid } PadMode;
|
||||
|
||||
#ifdef ENABLE_ARM
|
||||
#define MS_FLOAT32X4 float32x4_t
|
||||
|
|
|
@ -28,6 +28,7 @@ typedef struct PoolingParameter {
|
|||
OpParameter op_parameter_;
|
||||
PoolMode pool_mode_;
|
||||
RoundMode round_mode_;
|
||||
PadMode pad_mode_;
|
||||
ActType act_type_;
|
||||
int avg_mode_;
|
||||
bool global_;
|
||||
|
|
|
@ -48,6 +48,18 @@ OpParameter *PopulateConvParameter(const mindspore::lite::PrimitiveC *primitive)
|
|||
conv_param->input_channel_ = conv_primitive->GetChannelIn();
|
||||
conv_param->output_channel_ = conv_primitive->GetChannelOut();
|
||||
conv_param->group_ = conv_primitive->GetGroup();
|
||||
auto pad_mode = conv_primitive->GetPadMode();
|
||||
switch (pad_mode) {
|
||||
case schema::PadMode_SAME_UPPER:
|
||||
conv_param->pad_mode_ = Pad_Same;
|
||||
break;
|
||||
case schema::PadMode_VALID:
|
||||
conv_param->pad_mode_ = Pad_Valid;
|
||||
break;
|
||||
default:
|
||||
conv_param->pad_mode_ = Pad_No;
|
||||
break;
|
||||
}
|
||||
auto act_type = conv_primitive->GetActivationType();
|
||||
switch (act_type) {
|
||||
case schema::ActivationType_RELU:
|
||||
|
|
|
@ -46,6 +46,18 @@ OpParameter *PopulateConvDwParameter(const mindspore::lite::PrimitiveC *primitiv
|
|||
conv_param->input_channel_ = convdw_lite_primitive->GetInputChannel();
|
||||
conv_param->dilation_h_ = conv_primitive->GetDilateH();
|
||||
conv_param->dilation_w_ = conv_primitive->GetDilateW();
|
||||
auto pad_mode = conv_primitive->GetPadMode();
|
||||
switch (pad_mode) {
|
||||
case schema::PadMode_SAME_UPPER:
|
||||
conv_param->pad_mode_ = Pad_Same;
|
||||
break;
|
||||
case schema::PadMode_VALID:
|
||||
conv_param->pad_mode_ = Pad_Valid;
|
||||
break;
|
||||
default:
|
||||
conv_param->pad_mode_ = Pad_No;
|
||||
break;
|
||||
}
|
||||
auto act_type = conv_primitive->GetActivationType();
|
||||
switch (act_type) {
|
||||
case schema::ActivationType_RELU:
|
||||
|
|
|
@ -31,6 +31,7 @@ ge::Format ConverterToNPUFormat(schema::Format format) {
|
|||
ge_format = ge::FORMAT_NCHW;
|
||||
break;
|
||||
case schema::Format_NHWC:
|
||||
case schema::Format_KHWC:
|
||||
ge_format = ge::FORMAT_NHWC;
|
||||
break;
|
||||
default:
|
||||
|
@ -79,7 +80,7 @@ hiai::op::Data *ConverterToNPUData(Tensor *src, const std::string &name) {
|
|||
MS_LOG(ERROR) << "new data failed.";
|
||||
return data;
|
||||
}
|
||||
ge::TensorDesc tensor_desc(ConverterToNPUShape(src->shape()), ConverterToNPUFormat(src->format()),
|
||||
ge::TensorDesc tensor_desc(ConverterToNPUShape(src->shape()), ge::FORMAT_NCHW,
|
||||
ConverterToNPUDataType(src->data_type()));
|
||||
data->update_input_desc_x(tensor_desc);
|
||||
return data;
|
||||
|
@ -91,7 +92,7 @@ std::shared_ptr<ge::Tensor> ConverterToNPUTensor(Tensor *src) {
|
|||
MS_LOG(ERROR) << "new ge_tensor failed.";
|
||||
return ge_tensor;
|
||||
}
|
||||
ge::TensorDesc tensor_desc(ConverterToNPUShape(src->shape()), ConverterToNPUFormat(src->format()),
|
||||
ge::TensorDesc tensor_desc(ConverterToNPUShape(src->shape()), ge::FORMAT_NCHW,
|
||||
ConverterToNPUDataType(src->data_type()));
|
||||
|
||||
ge_tensor->SetTensorDesc(tensor_desc);
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "src/runtime/agent/npu/subgraph_npu_kernel.h"
|
||||
#include <set>
|
||||
#include "include/errorcode.h"
|
||||
#include "src/runtime/agent/npu/npu_executor.h"
|
||||
#include "include/graph/operator.h"
|
||||
|
@ -72,6 +73,9 @@ domi::ModelBufferData *SubGraphNpuKernel::BuildIRModel() {
|
|||
int SubGraphNpuKernel::Run() { return this->executor_->Run(in_tensors_, out_tensors_, nodes_, nullptr); }
|
||||
|
||||
int SubGraphNpuKernel::BuildNPUInputOp() {
|
||||
std::set<schema::PrimitiveType> trans_nodes = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DeConv2D,
|
||||
schema::PrimitiveType_DepthwiseConv2D,
|
||||
schema::PrimitiveType_DeDepthwiseConv2D};
|
||||
int count = 0;
|
||||
subgraph_input_op_.clear();
|
||||
for (auto node : this->nodes_) {
|
||||
|
@ -79,9 +83,16 @@ int SubGraphNpuKernel::BuildNPUInputOp() {
|
|||
for (auto in_tensor : node->in_tensors()) {
|
||||
if (IsSubGraphInputTensor(in_tensor)) {
|
||||
auto tensor_name = node->name() + "_" + std::to_string(count++);
|
||||
auto shape = in_tensor->shape();
|
||||
if (trans_nodes.find(node->Type()) != trans_nodes.end()) {
|
||||
in_tensor->set_shape({shape[0], shape[3], shape[1], shape[2]});
|
||||
}
|
||||
auto data = mindspore::lite::ConverterToNPUData(in_tensor, tensor_name);
|
||||
subgraph_input_op_.push_back(*data);
|
||||
node_input_op.push_back(data);
|
||||
if (trans_nodes.find(node->Type()) != trans_nodes.end()) {
|
||||
in_tensor->set_shape(shape);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -183,5 +194,4 @@ int SubGraphNpuKernel::Prepare() {
|
|||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -26,7 +26,6 @@ using mindspore::lite::RET_ERROR;
|
|||
using mindspore::lite::RET_MEMORY_FAILED;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::ActivationType;
|
||||
using mindspore::schema::PadMode;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
ConvolutionBaseCPUKernel::~ConvolutionBaseCPUKernel() {
|
||||
|
|
|
@ -30,8 +30,6 @@
|
|||
#include "src/runtime/kernel/arm/base/layout_transform.h"
|
||||
|
||||
using mindspore::lite::InnerContext;
|
||||
using mindspore::schema::PadMode;
|
||||
using mindspore::schema::QuantType;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class ConvolutionBaseCPUKernel : public LiteKernel {
|
||||
|
|
|
@ -0,0 +1,83 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/runtime/kernel/npu/activation.h"
|
||||
#include "include/graph/op/all_ops.h"
|
||||
#include "src/kernel_registry.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kNPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::schema::PrimitiveType_Activation;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
// Reports whether this activation op can be offloaded to the NPU.
//
// The decision is taken from act_param_->type_ only; the inputs, outputs and opParameter
// arguments are not inspected. RELU, RELU6, SIGMOID, TANH, HSIGMOID and LEAKY_RELU are accepted.
//
// Returns RET_OK for a supported activation type, RET_ERROR (after logging) otherwise.
int ActivationNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs,
                                   const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter) {
  switch (act_param_->type_) {
    case schema::ActivationType_RELU:
    case schema::ActivationType_RELU6:
    case schema::ActivationType_SIGMOID:
    case schema::ActivationType_TANH:
    case schema::ActivationType_HSIGMOID:
    case schema::ActivationType_LEAKY_RELU:
      return RET_OK;
    default:
      // The message previously lacked the space before "when", gluing the op name to the text.
      MS_LOG(ERROR) << "Unsupport activation type for activation op " << name_ << " when running npu";
      return RET_ERROR;
  }
}
|
||||
|
||||
// Creates the HiAI Activation operator for this kernel and wires it to its producer.
//
// npu_inputs[0] is used as the "x" input of the new operator; inputs and outputs are not read.
// The activation mode is selected from act_param_->type_; for LEAKY_RELU the negative slope is
// taken from act_param_->alpha_.
// NOTE(review): the numeric mode codes (0, 1, 2, 5, 10, 14) are presumably the HiAI Activation
// "mode" enumeration — confirm against the DDK operator specification.
//
// Returns RET_OK on success; RET_ERROR when no producer operator is supplied, when the operator
// cannot be allocated, or when the activation type has no NPU mapping.
int ActivationNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs,
                                      const std::vector<lite::Tensor *> &outputs,
                                      const std::vector<ge::Operator *> &npu_inputs) {
  // Guard the dereference of npu_inputs[0] below.
  if (npu_inputs.empty() || npu_inputs[0] == nullptr) {
    MS_LOG(ERROR) << "Invalid npu input operator for activation op " << name_ << ".";
    return RET_ERROR;
  }

  act_ = new (std::nothrow) hiai::op::Activation(name_ + "_act");
  if (act_ == nullptr) {
    MS_LOG(ERROR) << "New activation npu operator for activation op " << name_ << " failed.";
    return RET_ERROR;
  }
  act_->set_input_x(*npu_inputs[0]);

  switch (act_param_->type_) {
    case schema::ActivationType_SIGMOID:
      act_->set_attr_mode(0);
      break;
    case schema::ActivationType_RELU:
      act_->set_attr_mode(1);
      break;
    case schema::ActivationType_TANH:
      act_->set_attr_mode(2);
      break;
    case schema::ActivationType_LEAKY_RELU:
      act_->set_attr_mode(5);
      act_->set_attr_negative_slope(act_param_->alpha_);
      break;
    case schema::ActivationType_HSIGMOID:
      act_->set_attr_mode(10);
      break;
    case schema::ActivationType_RELU6:
      act_->set_attr_mode(14);
      break;
    default:
      // act_ stays allocated here; it is released by the destructor.
      // The message previously lacked the space before "when".
      MS_LOG(ERROR) << "Unsupport activation type for activation op " << name_ << " when running npu";
      return RET_ERROR;
  }
  return RET_OK;
}
|
||||
|
||||
// Returns the Activation operator created by SetNPUInputs(), or nullptr if none was created.
// The enclosing namespace is already mindspore::kernel, so the extra qualification is dropped.
ge::Operator *ActivationNPUKernel::GetNPUOp() {
  return act_;
}
|
||||
|
||||
// Releases the Activation operator owned by this kernel.
ActivationNPUKernel::~ActivationNPUKernel() {
  // delete on a null pointer is a no-op, so no explicit null test is needed.
  delete act_;
  act_ = nullptr;
}
|
||||
|
||||
// Registers ActivationNPUKernel for the (kNPU, float32, Activation) key.
// NOTE(review): NPUKernelCreator is defined outside this view — presumably it constructs the
// kernel and runs IsSupport(); confirm.
REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_Activation, NPUKernelCreator<ActivationNPUKernel>)
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,47 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_ACTIVATION_NPU_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_ACTIVATION_NPU_H_
|
||||
|
||||
#include <vector>
|
||||
#include "include/graph/op/all_ops.h"
|
||||
#include "include/graph/compatible/all_ops.h"
|
||||
#include "src/runtime/kernel/npu/npu_kernel.h"
|
||||
#include "nnacl/fp32/activation_fp32.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class ActivationNPUKernel : public NPUKernel {
|
||||
public:
|
||||
ActivationNPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
|
||||
const mindspore::lite::PrimitiveC *primitive)
|
||||
: NPUKernel(parameter, inputs, outputs, ctx, primitive) {
|
||||
act_param_ = reinterpret_cast<ActivationParameter *>(parameter);
|
||||
}
|
||||
~ActivationNPUKernel() override;
|
||||
|
||||
int IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
|
||||
OpParameter *opParameter) override;
|
||||
int SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
|
||||
const std::vector<ge::Operator *> &npu_inputs) override;
|
||||
ge::Operator *GetNPUOp() override;
|
||||
|
||||
private:
|
||||
hiai::op::Activation *act_ = nullptr;
|
||||
ActivationParameter *act_param_;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_ACTIVATION_NPU_H_
|
|
@ -0,0 +1,82 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/runtime/kernel/npu/convolution_base_npu.h"
|
||||
#include "src/runtime/agent/npu/npu_converter_utils.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
// Releases the fused activation operator and the weight/bias Const operators.
ConvolutionBaseNPUKernel::~ConvolutionBaseNPUKernel() {
  // delete tolerates null pointers, so each member is released unconditionally and then cleared.
  delete act_;
  act_ = nullptr;

  delete weight_;
  weight_ = nullptr;

  delete bias_;
  bias_ = nullptr;
}
|
||||
|
||||
// Builds the Const operators holding the convolution weight and (optionally) bias.
//
// inputs[1] is the weight tensor and inputs[2], when present, the bias tensor. While the weight
// is converted, its shape is temporarily permuted from {d0, d1, d2, d3} to {d0, d3, d1, d2} and
// its format is set to NCHW; both are put back before returning.
// NOTE(review): the permutation presumably turns an N/H/W/C-ordered weight into N/C/H/W —
// confirm the expected weight layout with the caller.
//
// Differences from the previous implementation:
//   * inputs.size(), null tensors and the weight rank are validated before indexing
//     (weight_shape[3] used to be read unconditionally);
//   * a null tensor returned by ConverterToNPUTensor() is reported instead of being stored;
//   * the tensors get their ORIGINAL format back instead of being forced to NHWC.
//
// Returns RET_OK on success, RET_ERROR on invalid input, allocation or conversion failure.
int ConvolutionBaseNPUKernel::InitWeightBiasConst(const std::vector<lite::Tensor *> &inputs) {
  if (inputs.size() < 2 || inputs[1] == nullptr) {
    MS_LOG(ERROR) << "Weight tensor is missing for op " << name_ << ".";
    return RET_ERROR;
  }
  auto *weight = inputs[1];
  const auto weight_shape = weight->shape();
  if (weight_shape.size() != 4) {
    MS_LOG(ERROR) << "Weight of op " << name_ << " must be 4-dimensional, got " << weight_shape.size() << ".";
    return RET_ERROR;
  }

  weight_ = new (std::nothrow) hiai::op::Const(name_ + "_w");
  if (weight_ == nullptr) {
    MS_LOG(ERROR) << "New weight const failed.";
    return RET_ERROR;
  }

  // Present the weight in the permuted layout only for the duration of the conversion.
  const auto weight_format = weight->format();
  weight->set_shape({weight_shape[0], weight_shape[3], weight_shape[1], weight_shape[2]});
  weight->set_format(schema::Format_NCHW);
  auto weight_tensor = mindspore::lite::ConverterToNPUTensor(weight);
  weight->set_shape(weight_shape);
  weight->set_format(weight_format);
  if (weight_tensor == nullptr) {
    MS_LOG(ERROR) << "Convert weight tensor of op " << name_ << " failed.";
    return RET_ERROR;
  }
  weight_->set_attr_value(weight_tensor);

  if (inputs.size() >= 3) {
    auto *bias = inputs[2];
    if (bias == nullptr) {
      MS_LOG(ERROR) << "Bias tensor of op " << name_ << " is null.";
      return RET_ERROR;
    }
    bias_ = new (std::nothrow) hiai::op::Const(name_ + "_b");
    if (bias_ == nullptr) {
      MS_LOG(ERROR) << "New bias const failed.";
      return RET_ERROR;
    }
    const auto bias_format = bias->format();
    bias->set_format(schema::Format_NCHW);
    auto bias_tensor = mindspore::lite::ConverterToNPUTensor(bias);
    bias->set_format(bias_format);
    if (bias_tensor == nullptr) {
      MS_LOG(ERROR) << "Convert bias tensor of op " << name_ << " failed.";
      return RET_ERROR;
    }
    bias_->set_attr_value(bias_tensor);
  }
  return RET_OK;
}
|
||||
|
||||
// Creates the fused Activation operator that follows a convolution.
//
// `input` becomes the "x" input of the new operator. Only ActType_Relu and ActType_Relu6 are
// handled; any other act_type is logged and rejected.
// NOTE(review): mode codes 1 and 14 are presumably the HiAI Activation values for ReLU and
// ReLU6 — confirm against the DDK operator specification.
//
// Returns RET_OK on success, RET_ERROR on allocation failure or unsupported activation.
int ConvolutionBaseNPUKernel::SetActivation(const ge::Operator *input, ActType act_type) {
  act_ = new (std::nothrow) hiai::op::Activation(name_ + "_act");
  if (!act_) {
    MS_LOG(ERROR) << "New activation npu operator for op " << name_ << " failed.";
    return RET_ERROR;
  }
  act_->set_input_x(*input);

  switch (act_type) {
    case ActType_Relu:
      act_->set_attr_mode(1);
      return RET_OK;
    case ActType_Relu6:
      act_->set_attr_mode(14);
      return RET_OK;
    default:
      MS_LOG(ERROR) << "Unsupport activation for convolution.";
      return RET_ERROR;
  }
}
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_CONVOLUTION_BASE_NPU_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_CONVOLUTION_BASE_NPU_H_
|
||||
|
||||
#include <vector>
|
||||
#include "include/graph/op/all_ops.h"
|
||||
#include "src/runtime/kernel/npu/transpose_base_npu.h"
|
||||
#include "nnacl/conv_parameter.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class ConvolutionBaseNPUKernel : public TransposeBaseNPUKernel {
|
||||
public:
|
||||
ConvolutionBaseNPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
|
||||
const mindspore::lite::PrimitiveC *primitive)
|
||||
: TransposeBaseNPUKernel(parameter, inputs, outputs, ctx, primitive) {}
|
||||
~ConvolutionBaseNPUKernel() override;
|
||||
|
||||
protected:
|
||||
int InitWeightBiasConst(const std::vector<lite::Tensor *> &inputs);
|
||||
int SetActivation(const ge::Operator *input, ActType act_type);
|
||||
hiai::op::Activation *act_ = nullptr;
|
||||
hiai::op::Const *weight_ = nullptr;
|
||||
hiai::op::Const *bias_ = nullptr;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_CONVOLUTION_BASE_NPU_H_
|
|
@ -0,0 +1,111 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/runtime/kernel/npu/convolution_depthwise_npu.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "src/runtime/agent/npu/npu_converter_utils.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kNPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::schema::PrimitiveType_DepthwiseConv2D;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
// Support check for offloading a depthwise convolution to the NPU.
// Every configuration is currently accepted: the arguments are ignored and RET_OK is returned.
int ConvolutionDepthwiseNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs,
                                             const std::vector<lite::Tensor *> &outputs,
                                             OpParameter *opParameter) {
  const int status = RET_OK;
  return status;
}
|
||||
|
||||
int ConvolutionDepthwiseNPUKernel::SetConvDwParam() {
|
||||
conv_dw_->set_attr_strides(ge::AttrValue::LIST_INT({conv_param_->stride_h_, conv_param_->stride_w_}));
|
||||
conv_dw_->set_attr_dilations(ge::AttrValue::LIST_INT({conv_param_->dilation_h_, conv_param_->dilation_w_}));
|
||||
|
||||
if (conv_param_->pad_mode_ == Pad_Same) {
|
||||
conv_dw_->set_attr_pad_mode(ge::AttrValue::STR{"SAME"});
|
||||
conv_dw_->set_attr_pads(ge::AttrValue::LIST_INT({0, 0, 0, 0}));
|
||||
} else if (conv_param_->pad_mode_ == Pad_Valid) {
|
||||
conv_dw_->set_attr_pad_mode(ge::AttrValue::STR{"VALID"});
|
||||
conv_dw_->set_attr_pads(ge::AttrValue::LIST_INT({0, 0, 0, 0}));
|
||||
} else {
|
||||
conv_dw_->set_attr_pad_mode(ge::AttrValue::STR{"SPECIFIC"});
|
||||
conv_dw_->set_attr_pads(
|
||||
ge::AttrValue::LIST_INT({conv_param_->pad_u_, conv_param_->pad_d_, conv_param_->pad_l_, conv_param_->pad_r_}));
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
// Builds the NPU operator chain for a depthwise convolution:
//   pre-transpose -> ConvolutionDepthwise [-> Activation] -> post-transpose
//
// npu_inputs[0] feeds the pre-transpose; inputs supplies the weight (inputs[1]) and optional
// bias (inputs[2]) through InitWeightBiasConst(); outputs is not read.
//
// Differences from the previous implementation:
//   * npu_inputs is validated before npu_inputs[0] is used;
//   * the bias is attached whenever InitWeightBiasConst() created bias_ (it does so for
//     inputs.size() >= 3) instead of only when inputs.size() == 3;
//   * the producer of the post-transpose is chosen once instead of re-testing act_type_.
//
// Returns RET_OK on success and RET_ERROR if any step fails.
int ConvolutionDepthwiseNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs,
                                                const std::vector<lite::Tensor *> &outputs,
                                                const std::vector<ge::Operator *> &npu_inputs) {
  if (npu_inputs.empty() || npu_inputs[0] == nullptr) {
    MS_LOG(ERROR) << "Invalid npu input operator for convolution depthwise op " << name_ << ".";
    return RET_ERROR;
  }

  auto ret = SetPreTranspose(npu_inputs[0]);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "New pre transpose npu operator (NHWC -> NCHW) for op " << name_ << " failed.";
    return RET_ERROR;
  }

  // set conv attr param
  conv_dw_ = new (std::nothrow) hiai::op::ConvolutionDepthwise(name_ + "_conv_depthwise");
  if (conv_dw_ == nullptr) {
    MS_LOG(ERROR) << "New convolution depthwise operator for op " << name_ << " failed.";
    return RET_ERROR;
  }
  ret = SetConvDwParam();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Set npu op parameter for convolution depthwise op " << name_ << " failed.";
    return RET_ERROR;
  }

  ret = InitWeightBiasConst(inputs);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Set weight and bias for convolution depthwise op " << name_ << " failed when running npu";
    return RET_ERROR;
  }
  conv_dw_->set_input_filter(*weight_);
  if (bias_ != nullptr) {
    conv_dw_->set_input_bias(*bias_);
  }
  conv_dw_->set_input_x(*pre_trans_);

  // The post-transpose consumes the activation when one is fused, otherwise the convolution.
  ge::Operator *post_input = conv_dw_;
  if (conv_param_->act_type_ != ActType_No) {
    ret = SetActivation(conv_dw_, conv_param_->act_type_);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "New activation npu operator for op " << name_ << " failed.";
      return RET_ERROR;
    }
    post_input = act_;
  }

  ret = SetPostTranspose(post_input);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "New post transpose npu operator (NCHW -> NHWC) for op " << name_ << " failed.";
    return RET_ERROR;
  }
  return RET_OK;
}
|
||||
|
||||
// Returns the post-transpose operator, i.e. the last operator of the chain built by SetNPUInputs().
// The enclosing namespace is already mindspore::kernel, so the extra qualification is dropped.
ge::Operator *ConvolutionDepthwiseNPUKernel::GetNPUOp() {
  return post_trans_;
}
|
||||
|
||||
// Releases the ConvolutionDepthwise operator owned by this kernel.
ConvolutionDepthwiseNPUKernel::~ConvolutionDepthwiseNPUKernel() {
  // delete on a null pointer is a no-op, so no explicit null test is needed.
  delete conv_dw_;
  conv_dw_ = nullptr;
}
|
||||
|
||||
// Registers ConvolutionDepthwiseNPUKernel for the (kNPU, float32, DepthwiseConv2D) key.
// NOTE(review): NPUKernelCreator is defined outside this view — presumably it constructs the
// kernel and runs IsSupport(); confirm.
REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_DepthwiseConv2D, NPUKernelCreator<ConvolutionDepthwiseNPUKernel>)
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,49 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_CONVOLUTION_DEPTHWISE_NPU_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_CONVOLUTION_DEPTHWISE_NPU_H_
|
||||
|
||||
#include <vector>
|
||||
#include "include/graph/op/all_ops.h"
|
||||
#include "include/graph/compatible/all_ops.h"
|
||||
#include "src/runtime/kernel/npu/convolution_base_npu.h"
|
||||
#include "src/runtime/kernel/npu/npu_kernel.h"
|
||||
#include "nnacl/conv_parameter.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class ConvolutionDepthwiseNPUKernel : public ConvolutionBaseNPUKernel {
|
||||
public:
|
||||
ConvolutionDepthwiseNPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
|
||||
const mindspore::lite::PrimitiveC *primitive)
|
||||
: ConvolutionBaseNPUKernel(parameter, inputs, outputs, ctx, primitive) {
|
||||
conv_param_ = reinterpret_cast<ConvParameter *>(parameter);
|
||||
}
|
||||
~ConvolutionDepthwiseNPUKernel() override;
|
||||
|
||||
int IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
|
||||
OpParameter *opParameter) override;
|
||||
int SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
|
||||
const std::vector<ge::Operator *> &npu_inputs) override;
|
||||
ge::Operator *GetNPUOp() override;
|
||||
|
||||
private:
|
||||
int SetConvDwParam();
|
||||
hiai::op::ConvolutionDepthwise *conv_dw_ = nullptr;
|
||||
ConvParameter *conv_param_;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_CONVOLUTION_DEPTHWISE_NPU_H_
|
|
@ -0,0 +1,111 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/runtime/kernel/npu/convolution_npu.h"
|
||||
#include "src/runtime/agent/npu/npu_converter_utils.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kNPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::schema::PrimitiveType_Conv2D;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
// Support check for offloading a convolution to the NPU.
// Every configuration is currently accepted: the arguments are ignored and RET_OK is returned.
int ConvolutionNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs,
                                    const std::vector<lite::Tensor *> &outputs,
                                    OpParameter *opParameter) {
  const int status = RET_OK;
  return status;
}
|
||||
|
||||
int ConvolutionNPUKernel::SetConvParam() {
|
||||
conv_->set_attr_strides(ge::AttrValue::LIST_INT({conv_param_->stride_h_, conv_param_->stride_w_}));
|
||||
conv_->set_attr_dilations(ge::AttrValue::LIST_INT({conv_param_->dilation_h_, conv_param_->dilation_w_}));
|
||||
conv_->set_attr_groups(1);
|
||||
|
||||
if (conv_param_->pad_mode_ == Pad_Same) {
|
||||
conv_->set_attr_pad_mode(ge::AttrValue::STR{"SAME"});
|
||||
conv_->set_attr_pads(ge::AttrValue::LIST_INT({0, 0, 0, 0}));
|
||||
} else if (conv_param_->pad_mode_ == Pad_Valid) {
|
||||
conv_->set_attr_pad_mode(ge::AttrValue::STR{"VALID"});
|
||||
conv_->set_attr_pads(ge::AttrValue::LIST_INT({0, 0, 0, 0}));
|
||||
} else {
|
||||
conv_->set_attr_pad_mode(ge::AttrValue::STR{"SPECIFIC"});
|
||||
conv_->set_attr_pads(
|
||||
ge::AttrValue::LIST_INT({conv_param_->pad_u_, conv_param_->pad_d_, conv_param_->pad_l_, conv_param_->pad_r_}));
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
// Builds the NPU operator chain for a convolution:
//   pre-transpose -> Convolution [-> Activation] -> post-transpose
//
// npu_inputs[0] feeds the pre-transpose; inputs supplies the weight (inputs[1]) and optional
// bias (inputs[2]) through InitWeightBiasConst(); outputs is not read.
//
// Differences from the previous implementation:
//   * npu_inputs is validated before npu_inputs[0] is used;
//   * the bias is attached whenever InitWeightBiasConst() created bias_ (it does so for
//     inputs.size() >= 3) instead of only when inputs.size() == 3;
//   * the producer of the post-transpose is chosen once instead of re-testing act_type_.
//
// Returns RET_OK on success and RET_ERROR if any step fails.
int ConvolutionNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs,
                                       const std::vector<lite::Tensor *> &outputs,
                                       const std::vector<ge::Operator *> &npu_inputs) {
  if (npu_inputs.empty() || npu_inputs[0] == nullptr) {
    MS_LOG(ERROR) << "Invalid npu input operator for convolution op " << name_ << ".";
    return RET_ERROR;
  }

  auto ret = SetPreTranspose(npu_inputs[0]);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "New pre transpose npu operator (NHWC -> NCHW) for op " << name_ << " failed.";
    return RET_ERROR;
  }

  // set conv attr param
  conv_ = new (std::nothrow) hiai::op::Convolution(name_ + "_conv");
  if (conv_ == nullptr) {
    MS_LOG(ERROR) << "New convolution operator for convolution op " << name_ << " failed.";
    return RET_ERROR;
  }
  ret = SetConvParam();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Set npu op parameter for convolution op " << name_ << " failed.";
    return RET_ERROR;
  }

  ret = InitWeightBiasConst(inputs);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Set weight and bias for convolution op " << name_ << " failed when running npu";
    return RET_ERROR;
  }
  conv_->set_input_filter(*weight_);
  if (bias_ != nullptr) {
    conv_->set_input_bias(*bias_);
  }
  conv_->set_input_x(*pre_trans_);

  // The post-transpose consumes the activation when one is fused, otherwise the convolution.
  ge::Operator *post_input = conv_;
  if (conv_param_->act_type_ != ActType_No) {
    ret = SetActivation(conv_, conv_param_->act_type_);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "New activation npu operator for op " << name_ << " failed.";
      return RET_ERROR;
    }
    post_input = act_;
  }

  ret = SetPostTranspose(post_input);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "New post transpose npu operator (NCHW -> NHWC) for op " << name_ << " failed.";
    return RET_ERROR;
  }
  return RET_OK;
}
|
||||
|
||||
// Returns the post-transpose operator, i.e. the last operator of the chain built by SetNPUInputs().
// The enclosing namespace is already mindspore::kernel, so the extra qualification is dropped.
ge::Operator *ConvolutionNPUKernel::GetNPUOp() {
  return post_trans_;
}
|
||||
|
||||
// Releases the Convolution operator owned by this kernel.
ConvolutionNPUKernel::~ConvolutionNPUKernel() {
  // delete on a null pointer is a no-op, so no explicit null test is needed.
  delete conv_;
  conv_ = nullptr;
}
|
||||
|
||||
// Registers ConvolutionNPUKernel for the (kNPU, float32, Conv2D) key.
// NOTE(review): NPUKernelCreator is defined outside this view — presumably it constructs the
// kernel and runs IsSupport(); confirm.
REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_Conv2D, NPUKernelCreator<ConvolutionNPUKernel>)
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,47 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_CONVOLUTION_NPU_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_CONVOLUTION_NPU_H_
|
||||
|
||||
#include <vector>
|
||||
#include "include/graph/op/all_ops.h"
|
||||
#include "src/runtime/kernel/npu/convolution_base_npu.h"
|
||||
#include "nnacl/conv_parameter.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class ConvolutionNPUKernel : public ConvolutionBaseNPUKernel {
|
||||
public:
|
||||
ConvolutionNPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
|
||||
const mindspore::lite::PrimitiveC *primitive)
|
||||
: ConvolutionBaseNPUKernel(parameter, inputs, outputs, ctx, primitive) {
|
||||
conv_param_ = reinterpret_cast<ConvParameter *>(parameter);
|
||||
}
|
||||
~ConvolutionNPUKernel() override;
|
||||
|
||||
int IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
|
||||
OpParameter *opParameter) override;
|
||||
int SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
|
||||
const std::vector<ge::Operator *> &npu_inputs) override;
|
||||
ge::Operator *GetNPUOp() override;
|
||||
|
||||
private:
|
||||
int SetConvParam();
|
||||
hiai::op::Convolution *conv_ = nullptr;
|
||||
ConvParameter *conv_param_;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_CONVOLUTION_NPU_H_
|
|
@ -0,0 +1,113 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/runtime/kernel/npu/pooling_npu.h"
|
||||
#include "src/kernel_registry.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kNPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::schema::PrimitiveType_Pooling;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
// Support check for running this pooling op on the NPU.
// No constraint is evaluated here: the arguments are not inspected and every
// configuration is reported as supported (RET_OK).
int PoolingNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
                                OpParameter *opParameter) {
  return RET_OK;
}
|
||||
|
||||
int PoolingNPUKernel::SetPoolingParam() {
|
||||
if (pooling_param_->pool_mode_ == PoolMode_MaxPool) {
|
||||
pooling_->set_attr_mode(0);
|
||||
} else if (pooling_param_->pool_mode_ == PoolMode_AvgPool) {
|
||||
pooling_->set_attr_mode(1);
|
||||
} else {
|
||||
pooling_->set_attr_mode(2);
|
||||
}
|
||||
pooling_->set_attr_global_pooling(pooling_param_->global_);
|
||||
pooling_->set_attr_window({pooling_param_->window_h_, pooling_param_->window_w_});
|
||||
pooling_->set_attr_stride({pooling_param_->stride_h_, pooling_param_->stride_w_});
|
||||
if (pooling_param_->pad_mode_ == Pad_Same) {
|
||||
pooling_->set_attr_pad_mode(6);
|
||||
pooling_->set_attr_pad({0, 0, 0, 0});
|
||||
} else if (pooling_param_->pad_mode_ == Pad_Valid) {
|
||||
pooling_->set_attr_pad_mode(5);
|
||||
pooling_->set_attr_pad({0, 0, 0, 0});
|
||||
} else {
|
||||
pooling_->set_attr_pad_mode(0);
|
||||
pooling_->set_attr_pad(
|
||||
{pooling_param_->pad_u_, pooling_param_->pad_d_, pooling_param_->pad_l_, pooling_param_->pad_r_});
|
||||
}
|
||||
|
||||
if (pooling_param_->round_mode_ == RoundMode_Floor) { // no use in cpu
|
||||
pooling_->set_attr_ceil_mode(0);
|
||||
} else {
|
||||
pooling_->set_attr_ceil_mode(1);
|
||||
}
|
||||
// todo data mode
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
// Builds the NPU operator chain for this pooling op:
//   npu_inputs[0] -> pre transpose (NHWC -> NCHW) -> PoolingD -> [activation] -> post transpose (NCHW -> NHWC)
//
// inputs / outputs: tensors of the kernel (not used here).
// npu_inputs:       upstream NPU operators; only the first entry is consumed.
//
// Returns RET_OK on success, RET_ERROR if the input operator is missing or any
// operator in the chain cannot be created/configured.
int PoolingNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs,
                                   const std::vector<lite::Tensor *> &outputs,
                                   const std::vector<ge::Operator *> &npu_inputs) {
  // Guard against an empty operator list or a null producer before it is dereferenced.
  if (npu_inputs.empty() || npu_inputs[0] == nullptr) {
    MS_LOG(ERROR) << "NPU input operator for op " << name_ << " is missing.";
    return RET_ERROR;
  }

  auto ret = SetPreTranspose(npu_inputs[0]);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "New pre transpose npu operator (NHWC -> NCHW) for op " << name_ << " failed.";
    return RET_ERROR;
  }

  pooling_ = new (std::nothrow) hiai::op::PoolingD(name_ + "_pooling");
  if (pooling_ == nullptr) {
    MS_LOG(ERROR) << "New pooling npu operator for op " << name_ << " failed.";
    return RET_ERROR;
  }
  ret = SetPoolingParam();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Set npu op parameter for pooling op " << name_ << " failed.";
    return RET_ERROR;
  }
  pooling_->set_input_x(*pre_trans_);

  // The post transpose consumes the activation when one is fused, otherwise the pooling op itself.
  const ge::Operator *last_op = pooling_;
  if (pooling_param_->act_type_ != ActType_No) {
    ret = SetActivation(pooling_, pooling_param_->act_type_);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "New activation npu operator for op " << name_ << " failed.";
      return RET_ERROR;
    }
    last_op = act_;
  }

  ret = SetPostTranspose(last_op);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "New post transpose npu operator (NCHW -> NHWC) for op " << name_ << " failed.";
    return RET_ERROR;
  }
  return RET_OK;
}
|
||||
|
||||
// Returns the last operator of the chain built by SetNPUInputs: the post transpose.
ge::Operator *PoolingNPUKernel::GetNPUOp() {
  return post_trans_;
}
|
||||
|
||||
// Releases the pooling operator allocated in SetNPUInputs.
PoolingNPUKernel::~PoolingNPUKernel() {
  // delete on a null pointer is a no-op, so no explicit guard is required.
  delete pooling_;
  pooling_ = nullptr;
}
|
||||
|
||||
// Kernel registration: associates (kNPU, float32, PrimitiveType_Pooling) with the
// NPUKernelCreator instantiated for PoolingNPUKernel.
REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_Pooling, NPUKernelCreator<PoolingNPUKernel>)
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,47 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_POOLING_NPU_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_POOLING_NPU_H_
|
||||
|
||||
#include <vector>
|
||||
#include "include/graph/op/all_ops.h"
|
||||
#include "src/runtime/kernel/npu/convolution_base_npu.h"
|
||||
#include "nnacl/pooling_parameter.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class PoolingNPUKernel : public ConvolutionBaseNPUKernel {
|
||||
public:
|
||||
PoolingNPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
|
||||
const mindspore::lite::PrimitiveC *primitive)
|
||||
: ConvolutionBaseNPUKernel(parameter, inputs, outputs, ctx, primitive) {
|
||||
pooling_param_ = reinterpret_cast<PoolingParameter *>(parameter);
|
||||
}
|
||||
~PoolingNPUKernel() override;
|
||||
|
||||
int IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
|
||||
OpParameter *opParameter) override;
|
||||
int SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
|
||||
const std::vector<ge::Operator *> &npu_inputs) override;
|
||||
ge::Operator *GetNPUOp() override;
|
||||
|
||||
private:
|
||||
int SetPoolingParam();
|
||||
hiai::op::PoolingD *pooling_ = nullptr;
|
||||
PoolingParameter *pooling_param_;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_POOLING_NPU_H_
|
|
@ -0,0 +1,54 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/runtime/kernel/npu/transpose_base_npu.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
// Frees the pre/post permute operators created by SetPreTranspose / SetPostTranspose.
TransposeBaseNPUKernel::~TransposeBaseNPUKernel() {
  // delete tolerates null pointers, so both members can be released unconditionally.
  delete pre_trans_;
  pre_trans_ = nullptr;

  delete post_trans_;
  post_trans_ = nullptr;
}
|
||||
|
||||
// Creates the input permute operator (NHWC -> NCHW) and connects it to `input`.
//
// input: upstream NPU operator feeding this kernel; must not be null.
//
// Returns RET_OK on success, RET_ERROR if `input` is null or the permute
// operator cannot be allocated. The created operator is stored in pre_trans_.
int TransposeBaseNPUKernel::SetPreTranspose(const ge::Operator *input) {
  // Reject a null producer up front; it is dereferenced below.
  if (input == nullptr) {
    MS_LOG(ERROR) << "Input npu operator of pre transpose for op " << name_ << " is nullptr.";
    return RET_ERROR;
  }
  // input permute: NHWC -> NCHW
  pre_trans_ = new (std::nothrow) hiai::op::Permute(name_ + "_pre_transpose");
  if (pre_trans_ == nullptr) {
    MS_LOG(ERROR) << "New pre transpose npu operator (NHWC -> NCHW) for op " << name_ << " failed.";
    return RET_ERROR;
  }
  pre_trans_->set_input_x(*input);
  pre_trans_->set_attr_order(ge::AttrValue::LIST_INT({0, 3, 1, 2}));
  return RET_OK;
}
|
||||
|
||||
// Creates the output permute operator (NCHW -> NHWC) and connects it to `input`.
//
// input: NPU operator whose output is to be transposed back; must not be null.
//
// Returns RET_OK on success, RET_ERROR if `input` is null or the permute
// operator cannot be allocated. The created operator is stored in post_trans_.
int TransposeBaseNPUKernel::SetPostTranspose(const ge::Operator *input) {
  // Reject a null producer up front; it is dereferenced below.
  if (input == nullptr) {
    MS_LOG(ERROR) << "Input npu operator of post transpose for op " << name_ << " is nullptr.";
    return RET_ERROR;
  }
  // permute: NCHW -> NHWC
  post_trans_ = new (std::nothrow) hiai::op::Permute(name_ + "_post_transpose");
  if (post_trans_ == nullptr) {
    MS_LOG(ERROR) << "New post transpose operator (NCHW -> NHWC) for op " << name_ << " failed.";
    return RET_ERROR;
  }
  post_trans_->set_input_x(*input);
  post_trans_->set_attr_order(ge::AttrValue::LIST_INT({0, 2, 3, 1}));
  return RET_OK;
}
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_TRANSPOSE_BASE_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_TRANSPOSE_BASE_H_
|
||||
|
||||
#include <vector>
|
||||
#include "include/graph/op/all_ops.h"
|
||||
#include "include/graph/compatible/all_ops.h"
|
||||
#include "src/runtime/kernel/npu/npu_kernel.h"
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class TransposeBaseNPUKernel : public NPUKernel {
|
||||
public:
|
||||
TransposeBaseNPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
|
||||
const mindspore::lite::PrimitiveC *primitive)
|
||||
: NPUKernel(parameter, inputs, outputs, ctx, primitive) {}
|
||||
~TransposeBaseNPUKernel() override;
|
||||
|
||||
protected:
|
||||
int SetPreTranspose(const ge::Operator *input);
|
||||
int SetPostTranspose(const ge::Operator *input);
|
||||
hiai::op::Permute *pre_trans_ = nullptr;
|
||||
hiai::op::Permute *post_trans_ = nullptr;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_TRANSPOSE_BASE_H_
|
Loading…
Reference in New Issue