forked from mindspore-Ecosystem/mindspore
!12508 [MS][LITE][CPU]fix bug of pad_fp16 op
From: @fuzhiye Reviewed-by: @zhang_xue_tong,@zhanghaibo5 Signed-off-by: @zhang_xue_tong
This commit is contained in:
commit
b16fcc8037
|
@ -26,6 +26,9 @@ using mindspore::lite::RET_OK;
|
||||||
using mindspore::schema::PrimitiveType_Pad;
|
using mindspore::schema::PrimitiveType_Pad;
|
||||||
|
|
||||||
namespace mindspore::kernel {
|
namespace mindspore::kernel {
|
||||||
|
namespace {
|
||||||
|
constexpr size_t kPadMaxInputSize = 2;
|
||||||
|
} // namespace
|
||||||
int PadFp16CPUKernel::RunImpl(int task_id) {
|
int PadFp16CPUKernel::RunImpl(int task_id) {
|
||||||
PadFp16(input_, output_, in_, out_, pad_param_->paddings_, task_id, context_->thread_num_);
|
PadFp16(input_, output_, in_, out_, pad_param_->paddings_, task_id, context_->thread_num_);
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
|
@ -48,6 +51,9 @@ int PadFp16CPUKernel::Run() {
|
||||||
|
|
||||||
int ret = 0;
|
int ret = 0;
|
||||||
if (pad_param_->pad_mode_ == static_cast<int>(schema::PaddingMode_CONSTANT)) {
|
if (pad_param_->pad_mode_ == static_cast<int>(schema::PaddingMode_CONSTANT)) {
|
||||||
|
if (in_tensors_.size() == kPadMaxInputSize) {
|
||||||
|
CopyPaddingFromInput();
|
||||||
|
}
|
||||||
if (pad_param_->constant_value_ - 0.0f < 1e-5) {
|
if (pad_param_->constant_value_ - 0.0f < 1e-5) {
|
||||||
memset(output_, 0, output_tensor->ElementsNum() * sizeof(float16_t));
|
memset(output_, 0, output_tensor->ElementsNum() * sizeof(float16_t));
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -47,7 +47,6 @@ class PadCPUKernel : public LiteKernel {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int CheckPaddings(int *paddings, int length, int *input_shape, int mode);
|
int CheckPaddings(int *paddings, int length, int *input_shape, int mode);
|
||||||
int CopyPaddingFromInput();
|
|
||||||
void CalculateStrides();
|
void CalculateStrides();
|
||||||
int ExtendShape(int *shape, int length, const int *ori_shape, int rank);
|
int ExtendShape(int *shape, int length, const int *ori_shape, int rank);
|
||||||
int ExtendPaddings(int *paddings, int length, const int *ori_paddings, int ori_length);
|
int ExtendPaddings(int *paddings, int length, const int *ori_paddings, int ori_length);
|
||||||
|
@ -55,6 +54,7 @@ class PadCPUKernel : public LiteKernel {
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
int HandleMirrorPad();
|
int HandleMirrorPad();
|
||||||
|
int CopyPaddingFromInput();
|
||||||
PadParameter *pad_param_ = nullptr;
|
PadParameter *pad_param_ = nullptr;
|
||||||
int in_[4] = {0};
|
int in_[4] = {0};
|
||||||
int out_[4] = {0};
|
int out_[4] = {0};
|
||||||
|
|
Loading…
Reference in New Issue